// pytorch (Fork: 0) / embedding_lookup_idx_avx2.cc - 4487 lines, 181.0 KB
//// --------------------------
//// ATTENTION:
//// THIS CODE IS AUTOGENERATED
//// BY hp_emblookup_codegen.py
//// DO NOT MODIFY!!!
//// --------------------------

#include <cmath>
#include <cstdint>

#include <immintrin.h>

#include <c10/util/BFloat16.h>
#include <c10/util/Half.h>
namespace caffe2 {
template <bool IS_WEIGHT_POSITIONAL>
14
static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma(
15
    const int64_t block_size,
16
    const int64_t output_size,
17
    const int64_t index_size,
18
    const int64_t data_size,
19
    const float* input,
20
    const int* indices,
21
    const int* offsets,
22
    const float* weights,
23
    const float* scale_bias,
24
    bool normalize_by_lengths,
25
    float* out) {
26
  const int prefdist_T0 = 16;
27
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
28
  const int fused_block_size = block_size + 0;
29
  int64_t dataInd = 0;
30
  if (block_size == 128) {
31
    // unrolling 16 times
32
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
33
      float* op = &out[rangeIndex * block_size];
34
      __m256 vop0 = _mm256_setzero_ps();
35
      __m256 vop8 = _mm256_setzero_ps();
36
      __m256 vop16 = _mm256_setzero_ps();
37
      __m256 vop24 = _mm256_setzero_ps();
38
      __m256 vop32 = _mm256_setzero_ps();
39
      __m256 vop40 = _mm256_setzero_ps();
40
      __m256 vop48 = _mm256_setzero_ps();
41
      __m256 vop56 = _mm256_setzero_ps();
42
      __m256 vop64 = _mm256_setzero_ps();
43
      __m256 vop72 = _mm256_setzero_ps();
44
      __m256 vop80 = _mm256_setzero_ps();
45
      __m256 vop88 = _mm256_setzero_ps();
46
      __m256 vop96 = _mm256_setzero_ps();
47
      __m256 vop104 = _mm256_setzero_ps();
48
      __m256 vop112 = _mm256_setzero_ps();
49
      __m256 vop120 = _mm256_setzero_ps();
50
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
51
        return false;
52
      }
53
      int64_t end_offset = offsets[rangeIndex + 1];
54
      int64_t length = end_offset - offsets[rangeIndex];
55
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
56
           ++dataInd) {
57
        const int idx = indices[dataInd];
58
        if (idx < 0 || idx >= data_size) {
59
          return false;
60
        }
61
        float wgt = 1.f;
62
        if (weights) {
63
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
64
        }
65
        __m256 vwgt = _mm256_set1_ps(wgt);
66
        const float* ip = &input[idx * fused_block_size];
67
        const int next_T0 = (dataInd < index_size - prefdist_T0)
68
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
69
            ? (dataInd + prefdist_T0)
70
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
71
            : dataInd;
72
        const int idx_pref_T0 = indices[next_T0];
73
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
74
          return false;
75
        }
76
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
77
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
78
        _mm_prefetch(
79
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
80
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
81
        // skip unnecessary prefetch of (&ip_next_T0[8])
82
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
83
        _mm_prefetch(
84
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
85
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
86
        // skip unnecessary prefetch of (&ip_next_T0[24])
87
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
88
        _mm_prefetch(
89
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
90
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
91
        // skip unnecessary prefetch of (&ip_next_T0[40])
92
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
93
        _mm_prefetch(
94
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
95
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
96
        // skip unnecessary prefetch of (&ip_next_T0[56])
97
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
98
        _mm_prefetch(
99
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
100
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
101
        // skip unnecessary prefetch of (&ip_next_T0[72])
102
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
103
        _mm_prefetch(
104
            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
105
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
106
        // skip unnecessary prefetch of (&ip_next_T0[88])
107
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
108
        _mm_prefetch(
109
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
110
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
111
        // skip unnecessary prefetch of (&ip_next_T0[104])
112
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
113
        _mm_prefetch(
114
            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
115
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
116
        // skip unnecessary prefetch of (&ip_next_T0[120])
117
      }
118
      if (!normalize_by_lengths || length == 0) {
119
        _mm256_storeu_ps(&op[0], vop0);
120
        _mm256_storeu_ps(&op[8], vop8);
121
        _mm256_storeu_ps(&op[16], vop16);
122
        _mm256_storeu_ps(&op[24], vop24);
123
        _mm256_storeu_ps(&op[32], vop32);
124
        _mm256_storeu_ps(&op[40], vop40);
125
        _mm256_storeu_ps(&op[48], vop48);
126
        _mm256_storeu_ps(&op[56], vop56);
127
        _mm256_storeu_ps(&op[64], vop64);
128
        _mm256_storeu_ps(&op[72], vop72);
129
        _mm256_storeu_ps(&op[80], vop80);
130
        _mm256_storeu_ps(&op[88], vop88);
131
        _mm256_storeu_ps(&op[96], vop96);
132
        _mm256_storeu_ps(&op[104], vop104);
133
        _mm256_storeu_ps(&op[112], vop112);
134
        _mm256_storeu_ps(&op[120], vop120);
135
      } else {
136
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
137
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
138
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
139
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
140
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
141
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
142
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
143
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
144
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
145
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
146
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
147
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
148
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
149
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
150
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
151
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
152
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
153
      }
154
    }
155
  } else if (block_size == 64) {
156
    // unrolling 8 times
157
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
158
      float* op = &out[rangeIndex * block_size];
159
      __m256 vop0 = _mm256_setzero_ps();
160
      __m256 vop8 = _mm256_setzero_ps();
161
      __m256 vop16 = _mm256_setzero_ps();
162
      __m256 vop24 = _mm256_setzero_ps();
163
      __m256 vop32 = _mm256_setzero_ps();
164
      __m256 vop40 = _mm256_setzero_ps();
165
      __m256 vop48 = _mm256_setzero_ps();
166
      __m256 vop56 = _mm256_setzero_ps();
167
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
168
        return false;
169
      }
170
      int64_t end_offset = offsets[rangeIndex + 1];
171
      int64_t length = end_offset - offsets[rangeIndex];
172
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
173
           ++dataInd) {
174
        const int idx = indices[dataInd];
175
        if (idx < 0 || idx >= data_size) {
176
          return false;
177
        }
178
        float wgt = 1.f;
179
        if (weights) {
180
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
181
        }
182
        __m256 vwgt = _mm256_set1_ps(wgt);
183
        const float* ip = &input[idx * fused_block_size];
184
        const int next_T0 = (dataInd < index_size - prefdist_T0)
185
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
186
            ? (dataInd + prefdist_T0)
187
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
188
            : dataInd;
189
        const int idx_pref_T0 = indices[next_T0];
190
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
191
          return false;
192
        }
193
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
194
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
195
        _mm_prefetch(
196
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
197
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
198
        // skip unnecessary prefetch of (&ip_next_T0[8])
199
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
200
        _mm_prefetch(
201
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
202
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
203
        // skip unnecessary prefetch of (&ip_next_T0[24])
204
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
205
        _mm_prefetch(
206
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
207
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
208
        // skip unnecessary prefetch of (&ip_next_T0[40])
209
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
210
        _mm_prefetch(
211
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
212
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
213
        // skip unnecessary prefetch of (&ip_next_T0[56])
214
      }
215
      if (!normalize_by_lengths || length == 0) {
216
        _mm256_storeu_ps(&op[0], vop0);
217
        _mm256_storeu_ps(&op[8], vop8);
218
        _mm256_storeu_ps(&op[16], vop16);
219
        _mm256_storeu_ps(&op[24], vop24);
220
        _mm256_storeu_ps(&op[32], vop32);
221
        _mm256_storeu_ps(&op[40], vop40);
222
        _mm256_storeu_ps(&op[48], vop48);
223
        _mm256_storeu_ps(&op[56], vop56);
224
      } else {
225
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
226
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
227
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
228
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
229
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
230
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
231
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
232
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
233
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
234
      }
235
    }
236
  } else if (block_size == 32) {
237
    // unrolling 4 times
238
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
239
      float* op = &out[rangeIndex * block_size];
240
      __m256 vop0 = _mm256_setzero_ps();
241
      __m256 vop8 = _mm256_setzero_ps();
242
      __m256 vop16 = _mm256_setzero_ps();
243
      __m256 vop24 = _mm256_setzero_ps();
244
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
245
        return false;
246
      }
247
      int64_t end_offset = offsets[rangeIndex + 1];
248
      int64_t length = end_offset - offsets[rangeIndex];
249
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
250
           ++dataInd) {
251
        const int idx = indices[dataInd];
252
        if (idx < 0 || idx >= data_size) {
253
          return false;
254
        }
255
        float wgt = 1.f;
256
        if (weights) {
257
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
258
        }
259
        __m256 vwgt = _mm256_set1_ps(wgt);
260
        const float* ip = &input[idx * fused_block_size];
261
        const int next_T0 = (dataInd < index_size - prefdist_T0)
262
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
263
            ? (dataInd + prefdist_T0)
264
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
265
            : dataInd;
266
        const int idx_pref_T0 = indices[next_T0];
267
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
268
          return false;
269
        }
270
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
271
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
272
        _mm_prefetch(
273
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
274
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
275
        // skip unnecessary prefetch of (&ip_next_T0[8])
276
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
277
        _mm_prefetch(
278
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
279
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
280
        // skip unnecessary prefetch of (&ip_next_T0[24])
281
      }
282
      if (!normalize_by_lengths || length == 0) {
283
        _mm256_storeu_ps(&op[0], vop0);
284
        _mm256_storeu_ps(&op[8], vop8);
285
        _mm256_storeu_ps(&op[16], vop16);
286
        _mm256_storeu_ps(&op[24], vop24);
287
      } else {
288
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
289
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
290
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
291
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
292
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
293
      }
294
    }
295
  } else if (block_size == 16) {
296
    // unrolling 2 times
297
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
298
      float* op = &out[rangeIndex * block_size];
299
      __m256 vop0 = _mm256_setzero_ps();
300
      __m256 vop8 = _mm256_setzero_ps();
301
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
302
        return false;
303
      }
304
      int64_t end_offset = offsets[rangeIndex + 1];
305
      int64_t length = end_offset - offsets[rangeIndex];
306
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
307
           ++dataInd) {
308
        const int idx = indices[dataInd];
309
        if (idx < 0 || idx >= data_size) {
310
          return false;
311
        }
312
        float wgt = 1.f;
313
        if (weights) {
314
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
315
        }
316
        __m256 vwgt = _mm256_set1_ps(wgt);
317
        const float* ip = &input[idx * fused_block_size];
318
        const int next_T0 = (dataInd < index_size - prefdist_T0)
319
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
320
            ? (dataInd + prefdist_T0)
321
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
322
            : dataInd;
323
        const int idx_pref_T0 = indices[next_T0];
324
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
325
          return false;
326
        }
327
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
328
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
329
        _mm_prefetch(
330
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
331
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
332
        // skip unnecessary prefetch of (&ip_next_T0[8])
333
      }
334
      if (!normalize_by_lengths || length == 0) {
335
        _mm256_storeu_ps(&op[0], vop0);
336
        _mm256_storeu_ps(&op[8], vop8);
337
      } else {
338
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
339
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
340
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
341
      }
342
    }
343
  } else {
344
    // generic code
345
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
346
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
347
      float* op = &out[rangeIndex * block_size];
348
      int64_t j = 0;
349
      for (; j + 8 <= block_size; j += 8) {
350
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
351
      }
352
      for (; j < block_size; j++) {
353
        op[j] = 0.0f;
354
      }
355
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
356
        return false;
357
      }
358
      int64_t end_offset = offsets[rangeIndex + 1];
359
      int64_t length = end_offset - offsets[rangeIndex];
360
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
361
           ++dataInd) {
362
        const int idx = indices[dataInd];
363
        if (idx < 0 || idx >= data_size) {
364
          return false;
365
        }
366
        float wgt = 1.f;
367
        if (weights) {
368
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
369
        }
370
        __m256 vwgt = _mm256_set1_ps(wgt);
371
        const float* ip = &input[idx * fused_block_size];
372
        const int next_T0 = (dataInd < index_size - prefdist_T0)
373
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
374
            ? (dataInd + prefdist_T0)
375
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
376
            : dataInd;
377
        const int idx_pref_T0 = indices[next_T0];
378
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
379
          return false;
380
        }
381
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
382
        j = 0;
383
        for (; j + 8 <= block_size; j += 8) {
384
          _mm256_storeu_ps(
385
              &op[j],
386
              _mm256_fmadd_ps(
387
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
388
          _mm_prefetch(
389
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
390
        }
391
        for (; j < block_size; j++) {
392
          op[j] = std::fma(wgt, ip[j], op[j]);
393
        }
394
      }
395
      if (normalize_by_lengths && length) {
396
        float len_inv = 1.0f / length;
397
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
398
        j = 0;
399
        for (; j + 8 <= block_size; j += 8) {
400
          _mm256_storeu_ps(
401
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
402
        }
403
        for (; j < block_size; j++) {
404
          op[j] = len_inv * op[j];
405
        }
406
      }
407
    }
408
  }
409
  return dataInd == index_size;
410
}
411
bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(
412
    const int64_t block_size,
413
    const int64_t output_size,
414
    const int64_t index_size,
415
    const int64_t data_size,
416
    const float* input,
417
    const int* indices,
418
    const int* offsets,
419
    const float* weights,
420
    const float* scale_bias,
421
    bool normalize_by_lengths,
422
    float* out) {
423
  return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<false>(
424
      block_size,
425
      output_size,
426
      index_size,
427
      data_size,
428
      input,
429
      indices,
430
      offsets,
431
      weights,
432
      scale_bias,
433
      normalize_by_lengths,
434
      out);
435
}
436
bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(
437
    const int64_t block_size,
438
    const int64_t output_size,
439
    const int64_t index_size,
440
    const int64_t data_size,
441
    const float* input,
442
    const int* indices,
443
    const int* offsets,
444
    const float* weights,
445
    const float* scale_bias,
446
    bool normalize_by_lengths,
447
    float* out) {
448
  return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<true>(
449
      block_size,
450
      output_size,
451
      index_size,
452
      data_size,
453
      input,
454
      indices,
455
      offsets,
456
      weights,
457
      scale_bias,
458
      normalize_by_lengths,
459
      out);
460
}
461

462
template <bool IS_WEIGHT_POSITIONAL>
463
static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma(
464
    const int64_t block_size,
465
    const int64_t output_size,
466
    const int64_t index_size,
467
    const int64_t data_size,
468
    const float* input,
469
    const int64_t* indices,
470
    const int64_t* offsets,
471
    const float* weights,
472
    const float* scale_bias,
473
    bool normalize_by_lengths,
474
    float* out) {
475
  const int64_t prefdist_T0 = 16;
476
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
477
  const int64_t fused_block_size = block_size + 0;
478
  int64_t dataInd = 0;
479
  if (block_size == 128) {
480
    // unrolling 16 times
481
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
482
      float* op = &out[rangeIndex * block_size];
483
      __m256 vop0 = _mm256_setzero_ps();
484
      __m256 vop8 = _mm256_setzero_ps();
485
      __m256 vop16 = _mm256_setzero_ps();
486
      __m256 vop24 = _mm256_setzero_ps();
487
      __m256 vop32 = _mm256_setzero_ps();
488
      __m256 vop40 = _mm256_setzero_ps();
489
      __m256 vop48 = _mm256_setzero_ps();
490
      __m256 vop56 = _mm256_setzero_ps();
491
      __m256 vop64 = _mm256_setzero_ps();
492
      __m256 vop72 = _mm256_setzero_ps();
493
      __m256 vop80 = _mm256_setzero_ps();
494
      __m256 vop88 = _mm256_setzero_ps();
495
      __m256 vop96 = _mm256_setzero_ps();
496
      __m256 vop104 = _mm256_setzero_ps();
497
      __m256 vop112 = _mm256_setzero_ps();
498
      __m256 vop120 = _mm256_setzero_ps();
499
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
500
        return false;
501
      }
502
      int64_t end_offset = offsets[rangeIndex + 1];
503
      int64_t length = end_offset - offsets[rangeIndex];
504
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
505
           ++dataInd) {
506
        const int64_t idx = indices[dataInd];
507
        if (idx < 0 || idx >= data_size) {
508
          return false;
509
        }
510
        float wgt = 1.f;
511
        if (weights) {
512
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
513
        }
514
        __m256 vwgt = _mm256_set1_ps(wgt);
515
        const float* ip = &input[idx * fused_block_size];
516
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
517
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
518
            ? (dataInd + prefdist_T0)
519
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
520
            : dataInd;
521
        const int64_t idx_pref_T0 = indices[next_T0];
522
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
523
          return false;
524
        }
525
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
526
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
527
        _mm_prefetch(
528
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
529
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
530
        // skip unnecessary prefetch of (&ip_next_T0[8])
531
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
532
        _mm_prefetch(
533
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
534
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
535
        // skip unnecessary prefetch of (&ip_next_T0[24])
536
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
537
        _mm_prefetch(
538
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
539
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
540
        // skip unnecessary prefetch of (&ip_next_T0[40])
541
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
542
        _mm_prefetch(
543
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
544
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
545
        // skip unnecessary prefetch of (&ip_next_T0[56])
546
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
547
        _mm_prefetch(
548
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
549
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
550
        // skip unnecessary prefetch of (&ip_next_T0[72])
551
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
552
        _mm_prefetch(
553
            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
554
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
555
        // skip unnecessary prefetch of (&ip_next_T0[88])
556
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
557
        _mm_prefetch(
558
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
559
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
560
        // skip unnecessary prefetch of (&ip_next_T0[104])
561
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
562
        _mm_prefetch(
563
            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
564
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
565
        // skip unnecessary prefetch of (&ip_next_T0[120])
566
      }
567
      if (!normalize_by_lengths || length == 0) {
568
        _mm256_storeu_ps(&op[0], vop0);
569
        _mm256_storeu_ps(&op[8], vop8);
570
        _mm256_storeu_ps(&op[16], vop16);
571
        _mm256_storeu_ps(&op[24], vop24);
572
        _mm256_storeu_ps(&op[32], vop32);
573
        _mm256_storeu_ps(&op[40], vop40);
574
        _mm256_storeu_ps(&op[48], vop48);
575
        _mm256_storeu_ps(&op[56], vop56);
576
        _mm256_storeu_ps(&op[64], vop64);
577
        _mm256_storeu_ps(&op[72], vop72);
578
        _mm256_storeu_ps(&op[80], vop80);
579
        _mm256_storeu_ps(&op[88], vop88);
580
        _mm256_storeu_ps(&op[96], vop96);
581
        _mm256_storeu_ps(&op[104], vop104);
582
        _mm256_storeu_ps(&op[112], vop112);
583
        _mm256_storeu_ps(&op[120], vop120);
584
      } else {
585
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
586
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
587
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
588
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
589
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
590
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
591
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
592
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
593
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
594
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
595
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
596
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
597
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
598
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
599
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
600
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
601
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
602
      }
603
    }
604
  } else if (block_size == 64) {
605
    // unrolling 8 times
606
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
607
      float* op = &out[rangeIndex * block_size];
608
      __m256 vop0 = _mm256_setzero_ps();
609
      __m256 vop8 = _mm256_setzero_ps();
610
      __m256 vop16 = _mm256_setzero_ps();
611
      __m256 vop24 = _mm256_setzero_ps();
612
      __m256 vop32 = _mm256_setzero_ps();
613
      __m256 vop40 = _mm256_setzero_ps();
614
      __m256 vop48 = _mm256_setzero_ps();
615
      __m256 vop56 = _mm256_setzero_ps();
616
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
617
        return false;
618
      }
619
      int64_t end_offset = offsets[rangeIndex + 1];
620
      int64_t length = end_offset - offsets[rangeIndex];
621
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
622
           ++dataInd) {
623
        const int64_t idx = indices[dataInd];
624
        if (idx < 0 || idx >= data_size) {
625
          return false;
626
        }
627
        float wgt = 1.f;
628
        if (weights) {
629
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
630
        }
631
        __m256 vwgt = _mm256_set1_ps(wgt);
632
        const float* ip = &input[idx * fused_block_size];
633
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
634
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
635
            ? (dataInd + prefdist_T0)
636
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
637
            : dataInd;
638
        const int64_t idx_pref_T0 = indices[next_T0];
639
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
640
          return false;
641
        }
642
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
643
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
644
        _mm_prefetch(
645
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
646
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
647
        // skip unnecessary prefetch of (&ip_next_T0[8])
648
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
649
        _mm_prefetch(
650
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
651
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
652
        // skip unnecessary prefetch of (&ip_next_T0[24])
653
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
654
        _mm_prefetch(
655
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
656
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
657
        // skip unnecessary prefetch of (&ip_next_T0[40])
658
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
659
        _mm_prefetch(
660
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
661
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
662
        // skip unnecessary prefetch of (&ip_next_T0[56])
663
      }
664
      if (!normalize_by_lengths || length == 0) {
665
        _mm256_storeu_ps(&op[0], vop0);
666
        _mm256_storeu_ps(&op[8], vop8);
667
        _mm256_storeu_ps(&op[16], vop16);
668
        _mm256_storeu_ps(&op[24], vop24);
669
        _mm256_storeu_ps(&op[32], vop32);
670
        _mm256_storeu_ps(&op[40], vop40);
671
        _mm256_storeu_ps(&op[48], vop48);
672
        _mm256_storeu_ps(&op[56], vop56);
673
      } else {
674
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
675
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
676
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
677
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
678
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
679
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
680
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
681
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
682
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
683
      }
684
    }
685
  } else if (block_size == 32) {
686
    // unrolling 4 times
687
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
688
      float* op = &out[rangeIndex * block_size];
689
      __m256 vop0 = _mm256_setzero_ps();
690
      __m256 vop8 = _mm256_setzero_ps();
691
      __m256 vop16 = _mm256_setzero_ps();
692
      __m256 vop24 = _mm256_setzero_ps();
693
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
694
        return false;
695
      }
696
      int64_t end_offset = offsets[rangeIndex + 1];
697
      int64_t length = end_offset - offsets[rangeIndex];
698
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
699
           ++dataInd) {
700
        const int64_t idx = indices[dataInd];
701
        if (idx < 0 || idx >= data_size) {
702
          return false;
703
        }
704
        float wgt = 1.f;
705
        if (weights) {
706
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
707
        }
708
        __m256 vwgt = _mm256_set1_ps(wgt);
709
        const float* ip = &input[idx * fused_block_size];
710
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
711
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
712
            ? (dataInd + prefdist_T0)
713
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
714
            : dataInd;
715
        const int64_t idx_pref_T0 = indices[next_T0];
716
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
717
          return false;
718
        }
719
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
720
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
721
        _mm_prefetch(
722
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
723
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
724
        // skip unnecessary prefetch of (&ip_next_T0[8])
725
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
726
        _mm_prefetch(
727
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
728
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
729
        // skip unnecessary prefetch of (&ip_next_T0[24])
730
      }
731
      if (!normalize_by_lengths || length == 0) {
732
        _mm256_storeu_ps(&op[0], vop0);
733
        _mm256_storeu_ps(&op[8], vop8);
734
        _mm256_storeu_ps(&op[16], vop16);
735
        _mm256_storeu_ps(&op[24], vop24);
736
      } else {
737
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
738
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
739
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
740
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
741
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
742
      }
743
    }
744
  } else if (block_size == 16) {
745
    // unrolling 2 times
746
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
747
      float* op = &out[rangeIndex * block_size];
748
      __m256 vop0 = _mm256_setzero_ps();
749
      __m256 vop8 = _mm256_setzero_ps();
750
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
751
        return false;
752
      }
753
      int64_t end_offset = offsets[rangeIndex + 1];
754
      int64_t length = end_offset - offsets[rangeIndex];
755
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
756
           ++dataInd) {
757
        const int64_t idx = indices[dataInd];
758
        if (idx < 0 || idx >= data_size) {
759
          return false;
760
        }
761
        float wgt = 1.f;
762
        if (weights) {
763
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
764
        }
765
        __m256 vwgt = _mm256_set1_ps(wgt);
766
        const float* ip = &input[idx * fused_block_size];
767
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
768
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
769
            ? (dataInd + prefdist_T0)
770
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
771
            : dataInd;
772
        const int64_t idx_pref_T0 = indices[next_T0];
773
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
774
          return false;
775
        }
776
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
777
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
778
        _mm_prefetch(
779
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
780
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
781
        // skip unnecessary prefetch of (&ip_next_T0[8])
782
      }
783
      if (!normalize_by_lengths || length == 0) {
784
        _mm256_storeu_ps(&op[0], vop0);
785
        _mm256_storeu_ps(&op[8], vop8);
786
      } else {
787
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
788
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
789
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
790
      }
791
    }
792
  } else {
793
    // generic code
794
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
795
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
796
      float* op = &out[rangeIndex * block_size];
797
      int64_t j = 0;
798
      for (; j + 8 <= block_size; j += 8) {
799
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
800
      }
801
      for (; j < block_size; j++) {
802
        op[j] = 0.0f;
803
      }
804
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
805
        return false;
806
      }
807
      int64_t end_offset = offsets[rangeIndex + 1];
808
      int64_t length = end_offset - offsets[rangeIndex];
809
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
810
           ++dataInd) {
811
        const int64_t idx = indices[dataInd];
812
        if (idx < 0 || idx >= data_size) {
813
          return false;
814
        }
815
        float wgt = 1.f;
816
        if (weights) {
817
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
818
        }
819
        __m256 vwgt = _mm256_set1_ps(wgt);
820
        const float* ip = &input[idx * fused_block_size];
821
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
822
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
823
            ? (dataInd + prefdist_T0)
824
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
825
            : dataInd;
826
        const int64_t idx_pref_T0 = indices[next_T0];
827
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
828
          return false;
829
        }
830
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
831
        j = 0;
832
        for (; j + 8 <= block_size; j += 8) {
833
          _mm256_storeu_ps(
834
              &op[j],
835
              _mm256_fmadd_ps(
836
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
837
          _mm_prefetch(
838
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
839
        }
840
        for (; j < block_size; j++) {
841
          op[j] = std::fma(wgt, ip[j], op[j]);
842
        }
843
      }
844
      if (normalize_by_lengths && length) {
845
        float len_inv = 1.0f / length;
846
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
847
        j = 0;
848
        for (; j + 8 <= block_size; j += 8) {
849
          _mm256_storeu_ps(
850
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
851
        }
852
        for (; j < block_size; j++) {
853
          op[j] = len_inv * op[j];
854
        }
855
      }
856
    }
857
  }
858
  return dataInd == index_size;
859
}
860
bool EmbeddingLookupIdx_int64_t_float_float_false__avx2_fma(
861
    const int64_t block_size,
862
    const int64_t output_size,
863
    const int64_t index_size,
864
    const int64_t data_size,
865
    const float* input,
866
    const int64_t* indices,
867
    const int64_t* offsets,
868
    const float* weights,
869
    const float* scale_bias,
870
    bool normalize_by_lengths,
871
    float* out) {
872
  return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<false>(
873
      block_size,
874
      output_size,
875
      index_size,
876
      data_size,
877
      input,
878
      indices,
879
      offsets,
880
      weights,
881
      scale_bias,
882
      normalize_by_lengths,
883
      out);
884
}
885
bool EmbeddingLookupIdx_int64_t_float_float_true__avx2_fma(
886
    const int64_t block_size,
887
    const int64_t output_size,
888
    const int64_t index_size,
889
    const int64_t data_size,
890
    const float* input,
891
    const int64_t* indices,
892
    const int64_t* offsets,
893
    const float* weights,
894
    const float* scale_bias,
895
    bool normalize_by_lengths,
896
    float* out) {
897
  return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<true>(
898
      block_size,
899
      output_size,
900
      index_size,
901
      data_size,
902
      input,
903
      indices,
904
      offsets,
905
      weights,
906
      scale_bias,
907
      normalize_by_lengths,
908
      out);
909
}
910

911
template <bool IS_WEIGHT_POSITIONAL>
912
static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma(
913
    const int64_t block_size,
914
    const int64_t output_size,
915
    const int64_t index_size,
916
    const int64_t data_size,
917
    const at::Half* input,
918
    const int* indices,
919
    const int* offsets,
920
    const float* weights,
921
    const float* scale_bias,
922
    bool normalize_by_lengths,
923
    float* out) {
924
  const int prefdist_T0 = 16;
925
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
926
  const int fused_block_size = block_size + 0;
927
  int64_t dataInd = 0;
928
  if (block_size == 128) {
929
    // unrolling 16 times
930
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
931
      float* op = &out[rangeIndex * block_size];
932
      __m256 vop0 = _mm256_setzero_ps();
933
      __m256 vop8 = _mm256_setzero_ps();
934
      __m256 vop16 = _mm256_setzero_ps();
935
      __m256 vop24 = _mm256_setzero_ps();
936
      __m256 vop32 = _mm256_setzero_ps();
937
      __m256 vop40 = _mm256_setzero_ps();
938
      __m256 vop48 = _mm256_setzero_ps();
939
      __m256 vop56 = _mm256_setzero_ps();
940
      __m256 vop64 = _mm256_setzero_ps();
941
      __m256 vop72 = _mm256_setzero_ps();
942
      __m256 vop80 = _mm256_setzero_ps();
943
      __m256 vop88 = _mm256_setzero_ps();
944
      __m256 vop96 = _mm256_setzero_ps();
945
      __m256 vop104 = _mm256_setzero_ps();
946
      __m256 vop112 = _mm256_setzero_ps();
947
      __m256 vop120 = _mm256_setzero_ps();
948
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
949
        return false;
950
      }
951
      int64_t end_offset = offsets[rangeIndex + 1];
952
      int64_t length = end_offset - offsets[rangeIndex];
953
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
954
           ++dataInd) {
955
        const int idx = indices[dataInd];
956
        if (idx < 0 || idx >= data_size) {
957
          return false;
958
        }
959
        float wgt = 1.f;
960
        if (weights) {
961
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
962
        }
963
        __m256 vwgt = _mm256_set1_ps(wgt);
964
        const at::Half* ip = &input[idx * fused_block_size];
965
        const int next_T0 = (dataInd < index_size - prefdist_T0)
966
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
967
            ? (dataInd + prefdist_T0)
968
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
969
            : dataInd;
970
        const int idx_pref_T0 = indices[next_T0];
971
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
972
          return false;
973
        }
974
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
975
        vop0 = _mm256_fmadd_ps(
976
            vwgt,
977
            _mm256_cvtph_ps(
978
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
979
            vop0);
980
        _mm_prefetch(
981
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
982
        vop8 = _mm256_fmadd_ps(
983
            vwgt,
984
            _mm256_cvtph_ps(
985
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
986
            vop8);
987
        // skip unnecessary prefetch of (&ip_next_T0[8])
988
        vop16 = _mm256_fmadd_ps(
989
            vwgt,
990
            _mm256_cvtph_ps(
991
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
992
            vop16);
993
        // skip unnecessary prefetch of (&ip_next_T0[16])
994
        vop24 = _mm256_fmadd_ps(
995
            vwgt,
996
            _mm256_cvtph_ps(
997
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
998
            vop24);
999
        // skip unnecessary prefetch of (&ip_next_T0[24])
1000
        vop32 = _mm256_fmadd_ps(
1001
            vwgt,
1002
            _mm256_cvtph_ps(
1003
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1004
            vop32);
1005
        _mm_prefetch(
1006
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1007
        vop40 = _mm256_fmadd_ps(
1008
            vwgt,
1009
            _mm256_cvtph_ps(
1010
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1011
            vop40);
1012
        // skip unnecessary prefetch of (&ip_next_T0[40])
1013
        vop48 = _mm256_fmadd_ps(
1014
            vwgt,
1015
            _mm256_cvtph_ps(
1016
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1017
            vop48);
1018
        // skip unnecessary prefetch of (&ip_next_T0[48])
1019
        vop56 = _mm256_fmadd_ps(
1020
            vwgt,
1021
            _mm256_cvtph_ps(
1022
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1023
            vop56);
1024
        // skip unnecessary prefetch of (&ip_next_T0[56])
1025
        vop64 = _mm256_fmadd_ps(
1026
            vwgt,
1027
            _mm256_cvtph_ps(
1028
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1029
            vop64);
1030
        _mm_prefetch(
1031
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1032
        vop72 = _mm256_fmadd_ps(
1033
            vwgt,
1034
            _mm256_cvtph_ps(
1035
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1036
            vop72);
1037
        // skip unnecessary prefetch of (&ip_next_T0[72])
1038
        vop80 = _mm256_fmadd_ps(
1039
            vwgt,
1040
            _mm256_cvtph_ps(
1041
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1042
            vop80);
1043
        // skip unnecessary prefetch of (&ip_next_T0[80])
1044
        vop88 = _mm256_fmadd_ps(
1045
            vwgt,
1046
            _mm256_cvtph_ps(
1047
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1048
            vop88);
1049
        // skip unnecessary prefetch of (&ip_next_T0[88])
1050
        vop96 = _mm256_fmadd_ps(
1051
            vwgt,
1052
            _mm256_cvtph_ps(
1053
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1054
            vop96);
1055
        _mm_prefetch(
1056
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1057
        vop104 = _mm256_fmadd_ps(
1058
            vwgt,
1059
            _mm256_cvtph_ps(
1060
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1061
            vop104);
1062
        // skip unnecessary prefetch of (&ip_next_T0[104])
1063
        vop112 = _mm256_fmadd_ps(
1064
            vwgt,
1065
            _mm256_cvtph_ps(
1066
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1067
            vop112);
1068
        // skip unnecessary prefetch of (&ip_next_T0[112])
1069
        vop120 = _mm256_fmadd_ps(
1070
            vwgt,
1071
            _mm256_cvtph_ps(
1072
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1073
            vop120);
1074
        // skip unnecessary prefetch of (&ip_next_T0[120])
1075
      }
1076
      if (!normalize_by_lengths || length == 0) {
1077
        _mm256_storeu_ps(&op[0], vop0);
1078
        _mm256_storeu_ps(&op[8], vop8);
1079
        _mm256_storeu_ps(&op[16], vop16);
1080
        _mm256_storeu_ps(&op[24], vop24);
1081
        _mm256_storeu_ps(&op[32], vop32);
1082
        _mm256_storeu_ps(&op[40], vop40);
1083
        _mm256_storeu_ps(&op[48], vop48);
1084
        _mm256_storeu_ps(&op[56], vop56);
1085
        _mm256_storeu_ps(&op[64], vop64);
1086
        _mm256_storeu_ps(&op[72], vop72);
1087
        _mm256_storeu_ps(&op[80], vop80);
1088
        _mm256_storeu_ps(&op[88], vop88);
1089
        _mm256_storeu_ps(&op[96], vop96);
1090
        _mm256_storeu_ps(&op[104], vop104);
1091
        _mm256_storeu_ps(&op[112], vop112);
1092
        _mm256_storeu_ps(&op[120], vop120);
1093
      } else {
1094
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1095
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1096
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1097
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1098
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1099
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1100
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1101
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1102
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1103
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1104
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1105
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1106
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1107
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1108
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1109
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1110
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1111
      }
1112
    }
1113
  } else if (block_size == 64) {
1114
    // unrolling 8 times
1115
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1116
      float* op = &out[rangeIndex * block_size];
1117
      __m256 vop0 = _mm256_setzero_ps();
1118
      __m256 vop8 = _mm256_setzero_ps();
1119
      __m256 vop16 = _mm256_setzero_ps();
1120
      __m256 vop24 = _mm256_setzero_ps();
1121
      __m256 vop32 = _mm256_setzero_ps();
1122
      __m256 vop40 = _mm256_setzero_ps();
1123
      __m256 vop48 = _mm256_setzero_ps();
1124
      __m256 vop56 = _mm256_setzero_ps();
1125
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1126
        return false;
1127
      }
1128
      int64_t end_offset = offsets[rangeIndex + 1];
1129
      int64_t length = end_offset - offsets[rangeIndex];
1130
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1131
           ++dataInd) {
1132
        const int idx = indices[dataInd];
1133
        if (idx < 0 || idx >= data_size) {
1134
          return false;
1135
        }
1136
        float wgt = 1.f;
1137
        if (weights) {
1138
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1139
        }
1140
        __m256 vwgt = _mm256_set1_ps(wgt);
1141
        const at::Half* ip = &input[idx * fused_block_size];
1142
        const int next_T0 = (dataInd < index_size - prefdist_T0)
1143
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1144
            ? (dataInd + prefdist_T0)
1145
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1146
            : dataInd;
1147
        const int idx_pref_T0 = indices[next_T0];
1148
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1149
          return false;
1150
        }
1151
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1152
        vop0 = _mm256_fmadd_ps(
1153
            vwgt,
1154
            _mm256_cvtph_ps(
1155
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1156
            vop0);
1157
        _mm_prefetch(
1158
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1159
        vop8 = _mm256_fmadd_ps(
1160
            vwgt,
1161
            _mm256_cvtph_ps(
1162
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1163
            vop8);
1164
        // skip unnecessary prefetch of (&ip_next_T0[8])
1165
        vop16 = _mm256_fmadd_ps(
1166
            vwgt,
1167
            _mm256_cvtph_ps(
1168
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1169
            vop16);
1170
        // skip unnecessary prefetch of (&ip_next_T0[16])
1171
        vop24 = _mm256_fmadd_ps(
1172
            vwgt,
1173
            _mm256_cvtph_ps(
1174
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1175
            vop24);
1176
        // skip unnecessary prefetch of (&ip_next_T0[24])
1177
        vop32 = _mm256_fmadd_ps(
1178
            vwgt,
1179
            _mm256_cvtph_ps(
1180
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1181
            vop32);
1182
        _mm_prefetch(
1183
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1184
        vop40 = _mm256_fmadd_ps(
1185
            vwgt,
1186
            _mm256_cvtph_ps(
1187
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1188
            vop40);
1189
        // skip unnecessary prefetch of (&ip_next_T0[40])
1190
        vop48 = _mm256_fmadd_ps(
1191
            vwgt,
1192
            _mm256_cvtph_ps(
1193
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1194
            vop48);
1195
        // skip unnecessary prefetch of (&ip_next_T0[48])
1196
        vop56 = _mm256_fmadd_ps(
1197
            vwgt,
1198
            _mm256_cvtph_ps(
1199
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1200
            vop56);
1201
        // skip unnecessary prefetch of (&ip_next_T0[56])
1202
      }
1203
      if (!normalize_by_lengths || length == 0) {
1204
        _mm256_storeu_ps(&op[0], vop0);
1205
        _mm256_storeu_ps(&op[8], vop8);
1206
        _mm256_storeu_ps(&op[16], vop16);
1207
        _mm256_storeu_ps(&op[24], vop24);
1208
        _mm256_storeu_ps(&op[32], vop32);
1209
        _mm256_storeu_ps(&op[40], vop40);
1210
        _mm256_storeu_ps(&op[48], vop48);
1211
        _mm256_storeu_ps(&op[56], vop56);
1212
      } else {
1213
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1214
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1215
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1216
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1217
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1218
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1219
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1220
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1221
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1222
      }
1223
    }
1224
  } else if (block_size == 32) {
1225
    // unrolling 4 times
1226
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1227
      float* op = &out[rangeIndex * block_size];
1228
      __m256 vop0 = _mm256_setzero_ps();
1229
      __m256 vop8 = _mm256_setzero_ps();
1230
      __m256 vop16 = _mm256_setzero_ps();
1231
      __m256 vop24 = _mm256_setzero_ps();
1232
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1233
        return false;
1234
      }
1235
      int64_t end_offset = offsets[rangeIndex + 1];
1236
      int64_t length = end_offset - offsets[rangeIndex];
1237
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1238
           ++dataInd) {
1239
        const int idx = indices[dataInd];
1240
        if (idx < 0 || idx >= data_size) {
1241
          return false;
1242
        }
1243
        float wgt = 1.f;
1244
        if (weights) {
1245
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1246
        }
1247
        __m256 vwgt = _mm256_set1_ps(wgt);
1248
        const at::Half* ip = &input[idx * fused_block_size];
1249
        const int next_T0 = (dataInd < index_size - prefdist_T0)
1250
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1251
            ? (dataInd + prefdist_T0)
1252
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1253
            : dataInd;
1254
        const int idx_pref_T0 = indices[next_T0];
1255
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1256
          return false;
1257
        }
1258
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1259
        vop0 = _mm256_fmadd_ps(
1260
            vwgt,
1261
            _mm256_cvtph_ps(
1262
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1263
            vop0);
1264
        _mm_prefetch(
1265
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1266
        vop8 = _mm256_fmadd_ps(
1267
            vwgt,
1268
            _mm256_cvtph_ps(
1269
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1270
            vop8);
1271
        // skip unnecessary prefetch of (&ip_next_T0[8])
1272
        vop16 = _mm256_fmadd_ps(
1273
            vwgt,
1274
            _mm256_cvtph_ps(
1275
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1276
            vop16);
1277
        // skip unnecessary prefetch of (&ip_next_T0[16])
1278
        vop24 = _mm256_fmadd_ps(
1279
            vwgt,
1280
            _mm256_cvtph_ps(
1281
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1282
            vop24);
1283
        // skip unnecessary prefetch of (&ip_next_T0[24])
1284
      }
1285
      if (!normalize_by_lengths || length == 0) {
1286
        _mm256_storeu_ps(&op[0], vop0);
1287
        _mm256_storeu_ps(&op[8], vop8);
1288
        _mm256_storeu_ps(&op[16], vop16);
1289
        _mm256_storeu_ps(&op[24], vop24);
1290
      } else {
1291
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1292
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1293
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1294
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1295
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1296
      }
1297
    }
1298
  } else if (block_size == 16) {
1299
    // unrolling 2 times
1300
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1301
      float* op = &out[rangeIndex * block_size];
1302
      __m256 vop0 = _mm256_setzero_ps();
1303
      __m256 vop8 = _mm256_setzero_ps();
1304
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1305
        return false;
1306
      }
1307
      int64_t end_offset = offsets[rangeIndex + 1];
1308
      int64_t length = end_offset - offsets[rangeIndex];
1309
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1310
           ++dataInd) {
1311
        const int idx = indices[dataInd];
1312
        if (idx < 0 || idx >= data_size) {
1313
          return false;
1314
        }
1315
        float wgt = 1.f;
1316
        if (weights) {
1317
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1318
        }
1319
        __m256 vwgt = _mm256_set1_ps(wgt);
1320
        const at::Half* ip = &input[idx * fused_block_size];
1321
        const int next_T0 = (dataInd < index_size - prefdist_T0)
1322
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1323
            ? (dataInd + prefdist_T0)
1324
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1325
            : dataInd;
1326
        const int idx_pref_T0 = indices[next_T0];
1327
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1328
          return false;
1329
        }
1330
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1331
        vop0 = _mm256_fmadd_ps(
1332
            vwgt,
1333
            _mm256_cvtph_ps(
1334
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1335
            vop0);
1336
        _mm_prefetch(
1337
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1338
        vop8 = _mm256_fmadd_ps(
1339
            vwgt,
1340
            _mm256_cvtph_ps(
1341
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1342
            vop8);
1343
        // skip unnecessary prefetch of (&ip_next_T0[8])
1344
      }
1345
      if (!normalize_by_lengths || length == 0) {
1346
        _mm256_storeu_ps(&op[0], vop0);
1347
        _mm256_storeu_ps(&op[8], vop8);
1348
      } else {
1349
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1350
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1351
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1352
      }
1353
    }
1354
  } else {
1355
    // generic code
1356
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
1357
    alignas(64) at::Half vtmp1[8] = {0};
1358
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1359
      float* op = &out[rangeIndex * block_size];
1360
      int64_t j = 0;
1361
      for (; j + 8 <= block_size; j += 8) {
1362
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1363
      }
1364
      for (; j < block_size; j++) {
1365
        op[j] = 0.0f;
1366
      }
1367
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1368
        return false;
1369
      }
1370
      int64_t end_offset = offsets[rangeIndex + 1];
1371
      int64_t length = end_offset - offsets[rangeIndex];
1372
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1373
           ++dataInd) {
1374
        const int idx = indices[dataInd];
1375
        if (idx < 0 || idx >= data_size) {
1376
          return false;
1377
        }
1378
        float wgt = 1.f;
1379
        if (weights) {
1380
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1381
        }
1382
        __m256 vwgt = _mm256_set1_ps(wgt);
1383
        const at::Half* ip = &input[idx * fused_block_size];
1384
        const int next_T0 = (dataInd < index_size - prefdist_T0)
1385
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1386
            ? (dataInd + prefdist_T0)
1387
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1388
            : dataInd;
1389
        const int idx_pref_T0 = indices[next_T0];
1390
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1391
          return false;
1392
        }
1393
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1394
        j = 0;
1395
        for (; j + 8 <= block_size; j += 8) {
1396
          _mm256_storeu_ps(
1397
              &op[j],
1398
              _mm256_fmadd_ps(
1399
                  vwgt,
1400
                  _mm256_cvtph_ps(_mm_loadu_si128(
1401
                      reinterpret_cast<const __m128i*>(&ip[j]))),
1402
                  _mm256_loadu_ps(&op[j])));
1403
          _mm_prefetch(
1404
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1405
        }
1406
        for (; j < block_size; j++) {
1407
          vtmp1[0] = ip[j];
1408
          __m256 vtmp2 =
1409
              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
1410
          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
1411
        }
1412
      }
1413
      if (normalize_by_lengths && length) {
1414
        float len_inv = 1.0f / length;
1415
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
1416
        j = 0;
1417
        for (; j + 8 <= block_size; j += 8) {
1418
          _mm256_storeu_ps(
1419
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1420
        }
1421
        for (; j < block_size; j++) {
1422
          op[j] = len_inv * op[j];
1423
        }
1424
      }
1425
    }
1426
  }
1427
  return dataInd == index_size;
1428
}
1429
bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(
1430
    const int64_t block_size,
1431
    const int64_t output_size,
1432
    const int64_t index_size,
1433
    const int64_t data_size,
1434
    const at::Half* input,
1435
    const int* indices,
1436
    const int* offsets,
1437
    const float* weights,
1438
    const float* scale_bias,
1439
    bool normalize_by_lengths,
1440
    float* out) {
1441
  return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<false>(
1442
      block_size,
1443
      output_size,
1444
      index_size,
1445
      data_size,
1446
      input,
1447
      indices,
1448
      offsets,
1449
      weights,
1450
      scale_bias,
1451
      normalize_by_lengths,
1452
      out);
1453
}
1454
bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(
1455
    const int64_t block_size,
1456
    const int64_t output_size,
1457
    const int64_t index_size,
1458
    const int64_t data_size,
1459
    const at::Half* input,
1460
    const int* indices,
1461
    const int* offsets,
1462
    const float* weights,
1463
    const float* scale_bias,
1464
    bool normalize_by_lengths,
1465
    float* out) {
1466
  return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<true>(
1467
      block_size,
1468
      output_size,
1469
      index_size,
1470
      data_size,
1471
      input,
1472
      indices,
1473
      offsets,
1474
      weights,
1475
      scale_bias,
1476
      normalize_by_lengths,
1477
      out);
1478
}
1479

1480
template <bool IS_WEIGHT_POSITIONAL>
1481
static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma(
1482
    const int64_t block_size,
1483
    const int64_t output_size,
1484
    const int64_t index_size,
1485
    const int64_t data_size,
1486
    const at::Half* input,
1487
    const int64_t* indices,
1488
    const int64_t* offsets,
1489
    const float* weights,
1490
    const float* scale_bias,
1491
    bool normalize_by_lengths,
1492
    float* out) {
1493
  const int64_t prefdist_T0 = 16;
1494
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1495
  const int64_t fused_block_size = block_size + 0;
1496
  int64_t dataInd = 0;
1497
  if (block_size == 128) {
1498
    // unrolling 16 times
1499
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1500
      float* op = &out[rangeIndex * block_size];
1501
      __m256 vop0 = _mm256_setzero_ps();
1502
      __m256 vop8 = _mm256_setzero_ps();
1503
      __m256 vop16 = _mm256_setzero_ps();
1504
      __m256 vop24 = _mm256_setzero_ps();
1505
      __m256 vop32 = _mm256_setzero_ps();
1506
      __m256 vop40 = _mm256_setzero_ps();
1507
      __m256 vop48 = _mm256_setzero_ps();
1508
      __m256 vop56 = _mm256_setzero_ps();
1509
      __m256 vop64 = _mm256_setzero_ps();
1510
      __m256 vop72 = _mm256_setzero_ps();
1511
      __m256 vop80 = _mm256_setzero_ps();
1512
      __m256 vop88 = _mm256_setzero_ps();
1513
      __m256 vop96 = _mm256_setzero_ps();
1514
      __m256 vop104 = _mm256_setzero_ps();
1515
      __m256 vop112 = _mm256_setzero_ps();
1516
      __m256 vop120 = _mm256_setzero_ps();
1517
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1518
        return false;
1519
      }
1520
      int64_t end_offset = offsets[rangeIndex + 1];
1521
      int64_t length = end_offset - offsets[rangeIndex];
1522
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1523
           ++dataInd) {
1524
        const int64_t idx = indices[dataInd];
1525
        if (idx < 0 || idx >= data_size) {
1526
          return false;
1527
        }
1528
        float wgt = 1.f;
1529
        if (weights) {
1530
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1531
        }
1532
        __m256 vwgt = _mm256_set1_ps(wgt);
1533
        const at::Half* ip = &input[idx * fused_block_size];
1534
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1535
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1536
            ? (dataInd + prefdist_T0)
1537
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1538
            : dataInd;
1539
        const int64_t idx_pref_T0 = indices[next_T0];
1540
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1541
          return false;
1542
        }
1543
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1544
        vop0 = _mm256_fmadd_ps(
1545
            vwgt,
1546
            _mm256_cvtph_ps(
1547
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1548
            vop0);
1549
        _mm_prefetch(
1550
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1551
        vop8 = _mm256_fmadd_ps(
1552
            vwgt,
1553
            _mm256_cvtph_ps(
1554
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1555
            vop8);
1556
        // skip unnecessary prefetch of (&ip_next_T0[8])
1557
        vop16 = _mm256_fmadd_ps(
1558
            vwgt,
1559
            _mm256_cvtph_ps(
1560
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1561
            vop16);
1562
        // skip unnecessary prefetch of (&ip_next_T0[16])
1563
        vop24 = _mm256_fmadd_ps(
1564
            vwgt,
1565
            _mm256_cvtph_ps(
1566
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1567
            vop24);
1568
        // skip unnecessary prefetch of (&ip_next_T0[24])
1569
        vop32 = _mm256_fmadd_ps(
1570
            vwgt,
1571
            _mm256_cvtph_ps(
1572
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1573
            vop32);
1574
        _mm_prefetch(
1575
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1576
        vop40 = _mm256_fmadd_ps(
1577
            vwgt,
1578
            _mm256_cvtph_ps(
1579
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1580
            vop40);
1581
        // skip unnecessary prefetch of (&ip_next_T0[40])
1582
        vop48 = _mm256_fmadd_ps(
1583
            vwgt,
1584
            _mm256_cvtph_ps(
1585
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1586
            vop48);
1587
        // skip unnecessary prefetch of (&ip_next_T0[48])
1588
        vop56 = _mm256_fmadd_ps(
1589
            vwgt,
1590
            _mm256_cvtph_ps(
1591
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1592
            vop56);
1593
        // skip unnecessary prefetch of (&ip_next_T0[56])
1594
        vop64 = _mm256_fmadd_ps(
1595
            vwgt,
1596
            _mm256_cvtph_ps(
1597
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),
1598
            vop64);
1599
        _mm_prefetch(
1600
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
1601
        vop72 = _mm256_fmadd_ps(
1602
            vwgt,
1603
            _mm256_cvtph_ps(
1604
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),
1605
            vop72);
1606
        // skip unnecessary prefetch of (&ip_next_T0[72])
1607
        vop80 = _mm256_fmadd_ps(
1608
            vwgt,
1609
            _mm256_cvtph_ps(
1610
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),
1611
            vop80);
1612
        // skip unnecessary prefetch of (&ip_next_T0[80])
1613
        vop88 = _mm256_fmadd_ps(
1614
            vwgt,
1615
            _mm256_cvtph_ps(
1616
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),
1617
            vop88);
1618
        // skip unnecessary prefetch of (&ip_next_T0[88])
1619
        vop96 = _mm256_fmadd_ps(
1620
            vwgt,
1621
            _mm256_cvtph_ps(
1622
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),
1623
            vop96);
1624
        _mm_prefetch(
1625
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
1626
        vop104 = _mm256_fmadd_ps(
1627
            vwgt,
1628
            _mm256_cvtph_ps(
1629
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),
1630
            vop104);
1631
        // skip unnecessary prefetch of (&ip_next_T0[104])
1632
        vop112 = _mm256_fmadd_ps(
1633
            vwgt,
1634
            _mm256_cvtph_ps(
1635
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),
1636
            vop112);
1637
        // skip unnecessary prefetch of (&ip_next_T0[112])
1638
        vop120 = _mm256_fmadd_ps(
1639
            vwgt,
1640
            _mm256_cvtph_ps(
1641
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),
1642
            vop120);
1643
        // skip unnecessary prefetch of (&ip_next_T0[120])
1644
      }
1645
      if (!normalize_by_lengths || length == 0) {
1646
        _mm256_storeu_ps(&op[0], vop0);
1647
        _mm256_storeu_ps(&op[8], vop8);
1648
        _mm256_storeu_ps(&op[16], vop16);
1649
        _mm256_storeu_ps(&op[24], vop24);
1650
        _mm256_storeu_ps(&op[32], vop32);
1651
        _mm256_storeu_ps(&op[40], vop40);
1652
        _mm256_storeu_ps(&op[48], vop48);
1653
        _mm256_storeu_ps(&op[56], vop56);
1654
        _mm256_storeu_ps(&op[64], vop64);
1655
        _mm256_storeu_ps(&op[72], vop72);
1656
        _mm256_storeu_ps(&op[80], vop80);
1657
        _mm256_storeu_ps(&op[88], vop88);
1658
        _mm256_storeu_ps(&op[96], vop96);
1659
        _mm256_storeu_ps(&op[104], vop104);
1660
        _mm256_storeu_ps(&op[112], vop112);
1661
        _mm256_storeu_ps(&op[120], vop120);
1662
      } else {
1663
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1664
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1665
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1666
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1667
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1668
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1669
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1670
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1671
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1672
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
1673
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
1674
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
1675
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
1676
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
1677
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
1678
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
1679
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
1680
      }
1681
    }
1682
  } else if (block_size == 64) {
1683
    // unrolling 8 times
1684
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1685
      float* op = &out[rangeIndex * block_size];
1686
      __m256 vop0 = _mm256_setzero_ps();
1687
      __m256 vop8 = _mm256_setzero_ps();
1688
      __m256 vop16 = _mm256_setzero_ps();
1689
      __m256 vop24 = _mm256_setzero_ps();
1690
      __m256 vop32 = _mm256_setzero_ps();
1691
      __m256 vop40 = _mm256_setzero_ps();
1692
      __m256 vop48 = _mm256_setzero_ps();
1693
      __m256 vop56 = _mm256_setzero_ps();
1694
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1695
        return false;
1696
      }
1697
      int64_t end_offset = offsets[rangeIndex + 1];
1698
      int64_t length = end_offset - offsets[rangeIndex];
1699
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1700
           ++dataInd) {
1701
        const int64_t idx = indices[dataInd];
1702
        if (idx < 0 || idx >= data_size) {
1703
          return false;
1704
        }
1705
        float wgt = 1.f;
1706
        if (weights) {
1707
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1708
        }
1709
        __m256 vwgt = _mm256_set1_ps(wgt);
1710
        const at::Half* ip = &input[idx * fused_block_size];
1711
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1712
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1713
            ? (dataInd + prefdist_T0)
1714
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1715
            : dataInd;
1716
        const int64_t idx_pref_T0 = indices[next_T0];
1717
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1718
          return false;
1719
        }
1720
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1721
        vop0 = _mm256_fmadd_ps(
1722
            vwgt,
1723
            _mm256_cvtph_ps(
1724
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1725
            vop0);
1726
        _mm_prefetch(
1727
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1728
        vop8 = _mm256_fmadd_ps(
1729
            vwgt,
1730
            _mm256_cvtph_ps(
1731
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1732
            vop8);
1733
        // skip unnecessary prefetch of (&ip_next_T0[8])
1734
        vop16 = _mm256_fmadd_ps(
1735
            vwgt,
1736
            _mm256_cvtph_ps(
1737
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1738
            vop16);
1739
        // skip unnecessary prefetch of (&ip_next_T0[16])
1740
        vop24 = _mm256_fmadd_ps(
1741
            vwgt,
1742
            _mm256_cvtph_ps(
1743
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1744
            vop24);
1745
        // skip unnecessary prefetch of (&ip_next_T0[24])
1746
        vop32 = _mm256_fmadd_ps(
1747
            vwgt,
1748
            _mm256_cvtph_ps(
1749
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),
1750
            vop32);
1751
        _mm_prefetch(
1752
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
1753
        vop40 = _mm256_fmadd_ps(
1754
            vwgt,
1755
            _mm256_cvtph_ps(
1756
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),
1757
            vop40);
1758
        // skip unnecessary prefetch of (&ip_next_T0[40])
1759
        vop48 = _mm256_fmadd_ps(
1760
            vwgt,
1761
            _mm256_cvtph_ps(
1762
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),
1763
            vop48);
1764
        // skip unnecessary prefetch of (&ip_next_T0[48])
1765
        vop56 = _mm256_fmadd_ps(
1766
            vwgt,
1767
            _mm256_cvtph_ps(
1768
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),
1769
            vop56);
1770
        // skip unnecessary prefetch of (&ip_next_T0[56])
1771
      }
1772
      if (!normalize_by_lengths || length == 0) {
1773
        _mm256_storeu_ps(&op[0], vop0);
1774
        _mm256_storeu_ps(&op[8], vop8);
1775
        _mm256_storeu_ps(&op[16], vop16);
1776
        _mm256_storeu_ps(&op[24], vop24);
1777
        _mm256_storeu_ps(&op[32], vop32);
1778
        _mm256_storeu_ps(&op[40], vop40);
1779
        _mm256_storeu_ps(&op[48], vop48);
1780
        _mm256_storeu_ps(&op[56], vop56);
1781
      } else {
1782
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1783
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1784
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1785
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1786
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1787
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
1788
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
1789
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
1790
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
1791
      }
1792
    }
1793
  } else if (block_size == 32) {
1794
    // unrolling 4 times
1795
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1796
      float* op = &out[rangeIndex * block_size];
1797
      __m256 vop0 = _mm256_setzero_ps();
1798
      __m256 vop8 = _mm256_setzero_ps();
1799
      __m256 vop16 = _mm256_setzero_ps();
1800
      __m256 vop24 = _mm256_setzero_ps();
1801
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1802
        return false;
1803
      }
1804
      int64_t end_offset = offsets[rangeIndex + 1];
1805
      int64_t length = end_offset - offsets[rangeIndex];
1806
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1807
           ++dataInd) {
1808
        const int64_t idx = indices[dataInd];
1809
        if (idx < 0 || idx >= data_size) {
1810
          return false;
1811
        }
1812
        float wgt = 1.f;
1813
        if (weights) {
1814
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1815
        }
1816
        __m256 vwgt = _mm256_set1_ps(wgt);
1817
        const at::Half* ip = &input[idx * fused_block_size];
1818
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1819
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1820
            ? (dataInd + prefdist_T0)
1821
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1822
            : dataInd;
1823
        const int64_t idx_pref_T0 = indices[next_T0];
1824
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1825
          return false;
1826
        }
1827
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1828
        vop0 = _mm256_fmadd_ps(
1829
            vwgt,
1830
            _mm256_cvtph_ps(
1831
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1832
            vop0);
1833
        _mm_prefetch(
1834
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1835
        vop8 = _mm256_fmadd_ps(
1836
            vwgt,
1837
            _mm256_cvtph_ps(
1838
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1839
            vop8);
1840
        // skip unnecessary prefetch of (&ip_next_T0[8])
1841
        vop16 = _mm256_fmadd_ps(
1842
            vwgt,
1843
            _mm256_cvtph_ps(
1844
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),
1845
            vop16);
1846
        // skip unnecessary prefetch of (&ip_next_T0[16])
1847
        vop24 = _mm256_fmadd_ps(
1848
            vwgt,
1849
            _mm256_cvtph_ps(
1850
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),
1851
            vop24);
1852
        // skip unnecessary prefetch of (&ip_next_T0[24])
1853
      }
1854
      if (!normalize_by_lengths || length == 0) {
1855
        _mm256_storeu_ps(&op[0], vop0);
1856
        _mm256_storeu_ps(&op[8], vop8);
1857
        _mm256_storeu_ps(&op[16], vop16);
1858
        _mm256_storeu_ps(&op[24], vop24);
1859
      } else {
1860
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1861
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1862
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1863
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
1864
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
1865
      }
1866
    }
1867
  } else if (block_size == 16) {
1868
    // unrolling 2 times
1869
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1870
      float* op = &out[rangeIndex * block_size];
1871
      __m256 vop0 = _mm256_setzero_ps();
1872
      __m256 vop8 = _mm256_setzero_ps();
1873
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1874
        return false;
1875
      }
1876
      int64_t end_offset = offsets[rangeIndex + 1];
1877
      int64_t length = end_offset - offsets[rangeIndex];
1878
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1879
           ++dataInd) {
1880
        const int64_t idx = indices[dataInd];
1881
        if (idx < 0 || idx >= data_size) {
1882
          return false;
1883
        }
1884
        float wgt = 1.f;
1885
        if (weights) {
1886
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1887
        }
1888
        __m256 vwgt = _mm256_set1_ps(wgt);
1889
        const at::Half* ip = &input[idx * fused_block_size];
1890
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1891
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1892
            ? (dataInd + prefdist_T0)
1893
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1894
            : dataInd;
1895
        const int64_t idx_pref_T0 = indices[next_T0];
1896
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1897
          return false;
1898
        }
1899
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1900
        vop0 = _mm256_fmadd_ps(
1901
            vwgt,
1902
            _mm256_cvtph_ps(
1903
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),
1904
            vop0);
1905
        _mm_prefetch(
1906
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
1907
        vop8 = _mm256_fmadd_ps(
1908
            vwgt,
1909
            _mm256_cvtph_ps(
1910
                _mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),
1911
            vop8);
1912
        // skip unnecessary prefetch of (&ip_next_T0[8])
1913
      }
1914
      if (!normalize_by_lengths || length == 0) {
1915
        _mm256_storeu_ps(&op[0], vop0);
1916
        _mm256_storeu_ps(&op[8], vop8);
1917
      } else {
1918
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
1919
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
1920
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
1921
      }
1922
    }
1923
  } else {
1924
    // generic code
1925
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
1926
    alignas(64) at::Half vtmp1[8] = {0};
1927
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
1928
      float* op = &out[rangeIndex * block_size];
1929
      int64_t j = 0;
1930
      for (; j + 8 <= block_size; j += 8) {
1931
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
1932
      }
1933
      for (; j < block_size; j++) {
1934
        op[j] = 0.0f;
1935
      }
1936
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
1937
        return false;
1938
      }
1939
      int64_t end_offset = offsets[rangeIndex + 1];
1940
      int64_t length = end_offset - offsets[rangeIndex];
1941
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
1942
           ++dataInd) {
1943
        const int64_t idx = indices[dataInd];
1944
        if (idx < 0 || idx >= data_size) {
1945
          return false;
1946
        }
1947
        float wgt = 1.f;
1948
        if (weights) {
1949
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
1950
        }
1951
        __m256 vwgt = _mm256_set1_ps(wgt);
1952
        const at::Half* ip = &input[idx * fused_block_size];
1953
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
1954
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1955
            ? (dataInd + prefdist_T0)
1956
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
1957
            : dataInd;
1958
        const int64_t idx_pref_T0 = indices[next_T0];
1959
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
1960
          return false;
1961
        }
1962
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
1963
        j = 0;
1964
        for (; j + 8 <= block_size; j += 8) {
1965
          _mm256_storeu_ps(
1966
              &op[j],
1967
              _mm256_fmadd_ps(
1968
                  vwgt,
1969
                  _mm256_cvtph_ps(_mm_loadu_si128(
1970
                      reinterpret_cast<const __m128i*>(&ip[j]))),
1971
                  _mm256_loadu_ps(&op[j])));
1972
          _mm_prefetch(
1973
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
1974
        }
1975
        for (; j < block_size; j++) {
1976
          vtmp1[0] = ip[j];
1977
          __m256 vtmp2 =
1978
              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
1979
          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
1980
        }
1981
      }
1982
      if (normalize_by_lengths && length) {
1983
        float len_inv = 1.0f / length;
1984
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
1985
        j = 0;
1986
        for (; j + 8 <= block_size; j += 8) {
1987
          _mm256_storeu_ps(
1988
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
1989
        }
1990
        for (; j < block_size; j++) {
1991
          op[j] = len_inv * op[j];
1992
        }
1993
      }
1994
    }
1995
  }
1996
  return dataInd == index_size;
1997
}
1998
bool EmbeddingLookupIdx_int64_t_half_float_false__avx2_fma(
1999
    const int64_t block_size,
2000
    const int64_t output_size,
2001
    const int64_t index_size,
2002
    const int64_t data_size,
2003
    const at::Half* input,
2004
    const int64_t* indices,
2005
    const int64_t* offsets,
2006
    const float* weights,
2007
    const float* scale_bias,
2008
    bool normalize_by_lengths,
2009
    float* out) {
2010
  return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<false>(
2011
      block_size,
2012
      output_size,
2013
      index_size,
2014
      data_size,
2015
      input,
2016
      indices,
2017
      offsets,
2018
      weights,
2019
      scale_bias,
2020
      normalize_by_lengths,
2021
      out);
2022
}
2023
bool EmbeddingLookupIdx_int64_t_half_float_true__avx2_fma(
2024
    const int64_t block_size,
2025
    const int64_t output_size,
2026
    const int64_t index_size,
2027
    const int64_t data_size,
2028
    const at::Half* input,
2029
    const int64_t* indices,
2030
    const int64_t* offsets,
2031
    const float* weights,
2032
    const float* scale_bias,
2033
    bool normalize_by_lengths,
2034
    float* out) {
2035
  return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<true>(
2036
      block_size,
2037
      output_size,
2038
      index_size,
2039
      data_size,
2040
      input,
2041
      indices,
2042
      offsets,
2043
      weights,
2044
      scale_bias,
2045
      normalize_by_lengths,
2046
      out);
2047
}
2048

2049
template <bool IS_WEIGHT_POSITIONAL>
2050
static bool EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma(
2051
    const int64_t block_size,
2052
    const int64_t output_size,
2053
    const int64_t index_size,
2054
    const int64_t data_size,
2055
    const at::BFloat16* input,
2056
    const int* indices,
2057
    const int* offsets,
2058
    const float* weights,
2059
    const float* scale_bias,
2060
    bool normalize_by_lengths,
2061
    float* out) {
2062
  const int prefdist_T0 = 16;
2063
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2064
  const int fused_block_size = block_size + 0;
2065
  int64_t dataInd = 0;
2066
  if (block_size == 128) {
2067
    // unrolling 16 times
2068
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2069
      float* op = &out[rangeIndex * block_size];
2070
      __m256 vop0 = _mm256_setzero_ps();
2071
      __m256 vop8 = _mm256_setzero_ps();
2072
      __m256 vop16 = _mm256_setzero_ps();
2073
      __m256 vop24 = _mm256_setzero_ps();
2074
      __m256 vop32 = _mm256_setzero_ps();
2075
      __m256 vop40 = _mm256_setzero_ps();
2076
      __m256 vop48 = _mm256_setzero_ps();
2077
      __m256 vop56 = _mm256_setzero_ps();
2078
      __m256 vop64 = _mm256_setzero_ps();
2079
      __m256 vop72 = _mm256_setzero_ps();
2080
      __m256 vop80 = _mm256_setzero_ps();
2081
      __m256 vop88 = _mm256_setzero_ps();
2082
      __m256 vop96 = _mm256_setzero_ps();
2083
      __m256 vop104 = _mm256_setzero_ps();
2084
      __m256 vop112 = _mm256_setzero_ps();
2085
      __m256 vop120 = _mm256_setzero_ps();
2086
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
2087
        return false;
2088
      }
2089
      int64_t end_offset = offsets[rangeIndex + 1];
2090
      int64_t length = end_offset - offsets[rangeIndex];
2091
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2092
           ++dataInd) {
2093
        const int idx = indices[dataInd];
2094
        if (idx < 0 || idx >= data_size) {
2095
          return false;
2096
        }
2097
        float wgt = 1.f;
2098
        if (weights) {
2099
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2100
        }
2101
        __m256 vwgt = _mm256_set1_ps(wgt);
2102
        const at::BFloat16* ip = &input[idx * fused_block_size];
2103
        const int next_T0 = (dataInd < index_size - prefdist_T0)
2104
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2105
            ? (dataInd + prefdist_T0)
2106
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2107
            : dataInd;
2108
        const int idx_pref_T0 = indices[next_T0];
2109
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2110
          return false;
2111
        }
2112
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2113
        vop0 = _mm256_fmadd_ps(
2114
            vwgt,
2115
            _mm256_castsi256_ps(_mm256_slli_epi32(
2116
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2117
                    reinterpret_cast<const __m128i*>(ip + (0)))),
2118
                16)),
2119
            vop0);
2120
        _mm_prefetch(
2121
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2122
        vop8 = _mm256_fmadd_ps(
2123
            vwgt,
2124
            _mm256_castsi256_ps(_mm256_slli_epi32(
2125
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2126
                    reinterpret_cast<const __m128i*>(ip + (8)))),
2127
                16)),
2128
            vop8);
2129
        // skip unnecessary prefetch of (&ip_next_T0[8])
2130
        vop16 = _mm256_fmadd_ps(
2131
            vwgt,
2132
            _mm256_castsi256_ps(_mm256_slli_epi32(
2133
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2134
                    reinterpret_cast<const __m128i*>(ip + (16)))),
2135
                16)),
2136
            vop16);
2137
        // skip unnecessary prefetch of (&ip_next_T0[16])
2138
        vop24 = _mm256_fmadd_ps(
2139
            vwgt,
2140
            _mm256_castsi256_ps(_mm256_slli_epi32(
2141
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2142
                    reinterpret_cast<const __m128i*>(ip + (24)))),
2143
                16)),
2144
            vop24);
2145
        // skip unnecessary prefetch of (&ip_next_T0[24])
2146
        vop32 = _mm256_fmadd_ps(
2147
            vwgt,
2148
            _mm256_castsi256_ps(_mm256_slli_epi32(
2149
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2150
                    reinterpret_cast<const __m128i*>(ip + (32)))),
2151
                16)),
2152
            vop32);
2153
        _mm_prefetch(
2154
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2155
        vop40 = _mm256_fmadd_ps(
2156
            vwgt,
2157
            _mm256_castsi256_ps(_mm256_slli_epi32(
2158
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2159
                    reinterpret_cast<const __m128i*>(ip + (40)))),
2160
                16)),
2161
            vop40);
2162
        // skip unnecessary prefetch of (&ip_next_T0[40])
2163
        vop48 = _mm256_fmadd_ps(
2164
            vwgt,
2165
            _mm256_castsi256_ps(_mm256_slli_epi32(
2166
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2167
                    reinterpret_cast<const __m128i*>(ip + (48)))),
2168
                16)),
2169
            vop48);
2170
        // skip unnecessary prefetch of (&ip_next_T0[48])
2171
        vop56 = _mm256_fmadd_ps(
2172
            vwgt,
2173
            _mm256_castsi256_ps(_mm256_slli_epi32(
2174
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2175
                    reinterpret_cast<const __m128i*>(ip + (56)))),
2176
                16)),
2177
            vop56);
2178
        // skip unnecessary prefetch of (&ip_next_T0[56])
2179
        vop64 = _mm256_fmadd_ps(
2180
            vwgt,
2181
            _mm256_castsi256_ps(_mm256_slli_epi32(
2182
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2183
                    reinterpret_cast<const __m128i*>(ip + (64)))),
2184
                16)),
2185
            vop64);
2186
        _mm_prefetch(
2187
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2188
        vop72 = _mm256_fmadd_ps(
2189
            vwgt,
2190
            _mm256_castsi256_ps(_mm256_slli_epi32(
2191
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2192
                    reinterpret_cast<const __m128i*>(ip + (72)))),
2193
                16)),
2194
            vop72);
2195
        // skip unnecessary prefetch of (&ip_next_T0[72])
2196
        vop80 = _mm256_fmadd_ps(
2197
            vwgt,
2198
            _mm256_castsi256_ps(_mm256_slli_epi32(
2199
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2200
                    reinterpret_cast<const __m128i*>(ip + (80)))),
2201
                16)),
2202
            vop80);
2203
        // skip unnecessary prefetch of (&ip_next_T0[80])
2204
        vop88 = _mm256_fmadd_ps(
2205
            vwgt,
2206
            _mm256_castsi256_ps(_mm256_slli_epi32(
2207
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2208
                    reinterpret_cast<const __m128i*>(ip + (88)))),
2209
                16)),
2210
            vop88);
2211
        // skip unnecessary prefetch of (&ip_next_T0[88])
2212
        vop96 = _mm256_fmadd_ps(
2213
            vwgt,
2214
            _mm256_castsi256_ps(_mm256_slli_epi32(
2215
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2216
                    reinterpret_cast<const __m128i*>(ip + (96)))),
2217
                16)),
2218
            vop96);
2219
        _mm_prefetch(
2220
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
2221
        vop104 = _mm256_fmadd_ps(
2222
            vwgt,
2223
            _mm256_castsi256_ps(_mm256_slli_epi32(
2224
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2225
                    reinterpret_cast<const __m128i*>(ip + (104)))),
2226
                16)),
2227
            vop104);
2228
        // skip unnecessary prefetch of (&ip_next_T0[104])
2229
        vop112 = _mm256_fmadd_ps(
2230
            vwgt,
2231
            _mm256_castsi256_ps(_mm256_slli_epi32(
2232
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2233
                    reinterpret_cast<const __m128i*>(ip + (112)))),
2234
                16)),
2235
            vop112);
2236
        // skip unnecessary prefetch of (&ip_next_T0[112])
2237
        vop120 = _mm256_fmadd_ps(
2238
            vwgt,
2239
            _mm256_castsi256_ps(_mm256_slli_epi32(
2240
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2241
                    reinterpret_cast<const __m128i*>(ip + (120)))),
2242
                16)),
2243
            vop120);
2244
        // skip unnecessary prefetch of (&ip_next_T0[120])
2245
      }
2246
      if (!normalize_by_lengths || length == 0) {
2247
        _mm256_storeu_ps(&op[0], vop0);
2248
        _mm256_storeu_ps(&op[8], vop8);
2249
        _mm256_storeu_ps(&op[16], vop16);
2250
        _mm256_storeu_ps(&op[24], vop24);
2251
        _mm256_storeu_ps(&op[32], vop32);
2252
        _mm256_storeu_ps(&op[40], vop40);
2253
        _mm256_storeu_ps(&op[48], vop48);
2254
        _mm256_storeu_ps(&op[56], vop56);
2255
        _mm256_storeu_ps(&op[64], vop64);
2256
        _mm256_storeu_ps(&op[72], vop72);
2257
        _mm256_storeu_ps(&op[80], vop80);
2258
        _mm256_storeu_ps(&op[88], vop88);
2259
        _mm256_storeu_ps(&op[96], vop96);
2260
        _mm256_storeu_ps(&op[104], vop104);
2261
        _mm256_storeu_ps(&op[112], vop112);
2262
        _mm256_storeu_ps(&op[120], vop120);
2263
      } else {
2264
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2265
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2266
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2267
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2268
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2269
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2270
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2271
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2272
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2273
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2274
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2275
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2276
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2277
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2278
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2279
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2280
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2281
      }
2282
    }
2283
  } else if (block_size == 64) {
2284
    // unrolling 8 times
2285
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2286
      float* op = &out[rangeIndex * block_size];
2287
      __m256 vop0 = _mm256_setzero_ps();
2288
      __m256 vop8 = _mm256_setzero_ps();
2289
      __m256 vop16 = _mm256_setzero_ps();
2290
      __m256 vop24 = _mm256_setzero_ps();
2291
      __m256 vop32 = _mm256_setzero_ps();
2292
      __m256 vop40 = _mm256_setzero_ps();
2293
      __m256 vop48 = _mm256_setzero_ps();
2294
      __m256 vop56 = _mm256_setzero_ps();
2295
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
2296
        return false;
2297
      }
2298
      int64_t end_offset = offsets[rangeIndex + 1];
2299
      int64_t length = end_offset - offsets[rangeIndex];
2300
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2301
           ++dataInd) {
2302
        const int idx = indices[dataInd];
2303
        if (idx < 0 || idx >= data_size) {
2304
          return false;
2305
        }
2306
        float wgt = 1.f;
2307
        if (weights) {
2308
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2309
        }
2310
        __m256 vwgt = _mm256_set1_ps(wgt);
2311
        const at::BFloat16* ip = &input[idx * fused_block_size];
2312
        const int next_T0 = (dataInd < index_size - prefdist_T0)
2313
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2314
            ? (dataInd + prefdist_T0)
2315
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2316
            : dataInd;
2317
        const int idx_pref_T0 = indices[next_T0];
2318
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2319
          return false;
2320
        }
2321
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2322
        vop0 = _mm256_fmadd_ps(
2323
            vwgt,
2324
            _mm256_castsi256_ps(_mm256_slli_epi32(
2325
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2326
                    reinterpret_cast<const __m128i*>(ip + (0)))),
2327
                16)),
2328
            vop0);
2329
        _mm_prefetch(
2330
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2331
        vop8 = _mm256_fmadd_ps(
2332
            vwgt,
2333
            _mm256_castsi256_ps(_mm256_slli_epi32(
2334
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2335
                    reinterpret_cast<const __m128i*>(ip + (8)))),
2336
                16)),
2337
            vop8);
2338
        // skip unnecessary prefetch of (&ip_next_T0[8])
2339
        vop16 = _mm256_fmadd_ps(
2340
            vwgt,
2341
            _mm256_castsi256_ps(_mm256_slli_epi32(
2342
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2343
                    reinterpret_cast<const __m128i*>(ip + (16)))),
2344
                16)),
2345
            vop16);
2346
        // skip unnecessary prefetch of (&ip_next_T0[16])
2347
        vop24 = _mm256_fmadd_ps(
2348
            vwgt,
2349
            _mm256_castsi256_ps(_mm256_slli_epi32(
2350
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2351
                    reinterpret_cast<const __m128i*>(ip + (24)))),
2352
                16)),
2353
            vop24);
2354
        // skip unnecessary prefetch of (&ip_next_T0[24])
2355
        vop32 = _mm256_fmadd_ps(
2356
            vwgt,
2357
            _mm256_castsi256_ps(_mm256_slli_epi32(
2358
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2359
                    reinterpret_cast<const __m128i*>(ip + (32)))),
2360
                16)),
2361
            vop32);
2362
        _mm_prefetch(
2363
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2364
        vop40 = _mm256_fmadd_ps(
2365
            vwgt,
2366
            _mm256_castsi256_ps(_mm256_slli_epi32(
2367
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2368
                    reinterpret_cast<const __m128i*>(ip + (40)))),
2369
                16)),
2370
            vop40);
2371
        // skip unnecessary prefetch of (&ip_next_T0[40])
2372
        vop48 = _mm256_fmadd_ps(
2373
            vwgt,
2374
            _mm256_castsi256_ps(_mm256_slli_epi32(
2375
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2376
                    reinterpret_cast<const __m128i*>(ip + (48)))),
2377
                16)),
2378
            vop48);
2379
        // skip unnecessary prefetch of (&ip_next_T0[48])
2380
        vop56 = _mm256_fmadd_ps(
2381
            vwgt,
2382
            _mm256_castsi256_ps(_mm256_slli_epi32(
2383
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2384
                    reinterpret_cast<const __m128i*>(ip + (56)))),
2385
                16)),
2386
            vop56);
2387
        // skip unnecessary prefetch of (&ip_next_T0[56])
2388
      }
2389
      if (!normalize_by_lengths || length == 0) {
2390
        _mm256_storeu_ps(&op[0], vop0);
2391
        _mm256_storeu_ps(&op[8], vop8);
2392
        _mm256_storeu_ps(&op[16], vop16);
2393
        _mm256_storeu_ps(&op[24], vop24);
2394
        _mm256_storeu_ps(&op[32], vop32);
2395
        _mm256_storeu_ps(&op[40], vop40);
2396
        _mm256_storeu_ps(&op[48], vop48);
2397
        _mm256_storeu_ps(&op[56], vop56);
2398
      } else {
2399
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2400
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2401
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2402
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2403
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2404
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2405
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2406
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2407
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2408
      }
2409
    }
2410
  } else if (block_size == 32) {
2411
    // unrolling 4 times
2412
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2413
      float* op = &out[rangeIndex * block_size];
2414
      __m256 vop0 = _mm256_setzero_ps();
2415
      __m256 vop8 = _mm256_setzero_ps();
2416
      __m256 vop16 = _mm256_setzero_ps();
2417
      __m256 vop24 = _mm256_setzero_ps();
2418
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
2419
        return false;
2420
      }
2421
      int64_t end_offset = offsets[rangeIndex + 1];
2422
      int64_t length = end_offset - offsets[rangeIndex];
2423
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2424
           ++dataInd) {
2425
        const int idx = indices[dataInd];
2426
        if (idx < 0 || idx >= data_size) {
2427
          return false;
2428
        }
2429
        float wgt = 1.f;
2430
        if (weights) {
2431
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2432
        }
2433
        __m256 vwgt = _mm256_set1_ps(wgt);
2434
        const at::BFloat16* ip = &input[idx * fused_block_size];
2435
        const int next_T0 = (dataInd < index_size - prefdist_T0)
2436
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2437
            ? (dataInd + prefdist_T0)
2438
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2439
            : dataInd;
2440
        const int idx_pref_T0 = indices[next_T0];
2441
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2442
          return false;
2443
        }
2444
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2445
        vop0 = _mm256_fmadd_ps(
2446
            vwgt,
2447
            _mm256_castsi256_ps(_mm256_slli_epi32(
2448
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2449
                    reinterpret_cast<const __m128i*>(ip + (0)))),
2450
                16)),
2451
            vop0);
2452
        _mm_prefetch(
2453
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2454
        vop8 = _mm256_fmadd_ps(
2455
            vwgt,
2456
            _mm256_castsi256_ps(_mm256_slli_epi32(
2457
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2458
                    reinterpret_cast<const __m128i*>(ip + (8)))),
2459
                16)),
2460
            vop8);
2461
        // skip unnecessary prefetch of (&ip_next_T0[8])
2462
        vop16 = _mm256_fmadd_ps(
2463
            vwgt,
2464
            _mm256_castsi256_ps(_mm256_slli_epi32(
2465
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2466
                    reinterpret_cast<const __m128i*>(ip + (16)))),
2467
                16)),
2468
            vop16);
2469
        // skip unnecessary prefetch of (&ip_next_T0[16])
2470
        vop24 = _mm256_fmadd_ps(
2471
            vwgt,
2472
            _mm256_castsi256_ps(_mm256_slli_epi32(
2473
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2474
                    reinterpret_cast<const __m128i*>(ip + (24)))),
2475
                16)),
2476
            vop24);
2477
        // skip unnecessary prefetch of (&ip_next_T0[24])
2478
      }
2479
      if (!normalize_by_lengths || length == 0) {
2480
        _mm256_storeu_ps(&op[0], vop0);
2481
        _mm256_storeu_ps(&op[8], vop8);
2482
        _mm256_storeu_ps(&op[16], vop16);
2483
        _mm256_storeu_ps(&op[24], vop24);
2484
      } else {
2485
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2486
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2487
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2488
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2489
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2490
      }
2491
    }
2492
  } else if (block_size == 16) {
2493
    // unrolling 2 times
2494
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2495
      float* op = &out[rangeIndex * block_size];
2496
      __m256 vop0 = _mm256_setzero_ps();
2497
      __m256 vop8 = _mm256_setzero_ps();
2498
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
2499
        return false;
2500
      }
2501
      int64_t end_offset = offsets[rangeIndex + 1];
2502
      int64_t length = end_offset - offsets[rangeIndex];
2503
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2504
           ++dataInd) {
2505
        const int idx = indices[dataInd];
2506
        if (idx < 0 || idx >= data_size) {
2507
          return false;
2508
        }
2509
        float wgt = 1.f;
2510
        if (weights) {
2511
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2512
        }
2513
        __m256 vwgt = _mm256_set1_ps(wgt);
2514
        const at::BFloat16* ip = &input[idx * fused_block_size];
2515
        const int next_T0 = (dataInd < index_size - prefdist_T0)
2516
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2517
            ? (dataInd + prefdist_T0)
2518
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2519
            : dataInd;
2520
        const int idx_pref_T0 = indices[next_T0];
2521
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2522
          return false;
2523
        }
2524
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2525
        vop0 = _mm256_fmadd_ps(
2526
            vwgt,
2527
            _mm256_castsi256_ps(_mm256_slli_epi32(
2528
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2529
                    reinterpret_cast<const __m128i*>(ip + (0)))),
2530
                16)),
2531
            vop0);
2532
        _mm_prefetch(
2533
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2534
        vop8 = _mm256_fmadd_ps(
2535
            vwgt,
2536
            _mm256_castsi256_ps(_mm256_slli_epi32(
2537
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2538
                    reinterpret_cast<const __m128i*>(ip + (8)))),
2539
                16)),
2540
            vop8);
2541
        // skip unnecessary prefetch of (&ip_next_T0[8])
2542
      }
2543
      if (!normalize_by_lengths || length == 0) {
2544
        _mm256_storeu_ps(&op[0], vop0);
2545
        _mm256_storeu_ps(&op[8], vop8);
2546
      } else {
2547
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2548
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2549
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2550
      }
2551
    }
2552
  } else {
2553
    // generic code
2554
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
2555
    alignas(64) at::BFloat16 vtmp1[8] = {0};
2556
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2557
      float* op = &out[rangeIndex * block_size];
2558
      int64_t j = 0;
2559
      for (; j + 8 <= block_size; j += 8) {
2560
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
2561
      }
2562
      for (; j < block_size; j++) {
2563
        op[j] = 0.0f;
2564
      }
2565
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
2566
        return false;
2567
      }
2568
      int64_t end_offset = offsets[rangeIndex + 1];
2569
      int64_t length = end_offset - offsets[rangeIndex];
2570
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2571
           ++dataInd) {
2572
        const int idx = indices[dataInd];
2573
        if (idx < 0 || idx >= data_size) {
2574
          return false;
2575
        }
2576
        float wgt = 1.f;
2577
        if (weights) {
2578
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2579
        }
2580
        __m256 vwgt = _mm256_set1_ps(wgt);
2581
        const at::BFloat16* ip = &input[idx * fused_block_size];
2582
        const int next_T0 = (dataInd < index_size - prefdist_T0)
2583
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2584
            ? (dataInd + prefdist_T0)
2585
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2586
            : dataInd;
2587
        const int idx_pref_T0 = indices[next_T0];
2588
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2589
          return false;
2590
        }
2591
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2592
        j = 0;
2593
        for (; j + 8 <= block_size; j += 8) {
2594
          _mm256_storeu_ps(
2595
              &op[j],
2596
              _mm256_fmadd_ps(
2597
                  vwgt,
2598
                  _mm256_castsi256_ps(_mm256_slli_epi32(
2599
                      _mm256_cvtepu16_epi32(_mm_loadu_si128(
2600
                          reinterpret_cast<const __m128i*>(&ip[j]))),
2601
                      16)),
2602
                  _mm256_loadu_ps(&op[j])));
2603
          _mm_prefetch(
2604
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
2605
        }
2606
        for (; j < block_size; j++) {
2607
          vtmp1[0] = ip[j];
2608
          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
2609
              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
2610
              16));
2611
          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
2612
        }
2613
      }
2614
      if (normalize_by_lengths && length) {
2615
        float len_inv = 1.0f / length;
2616
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
2617
        j = 0;
2618
        for (; j + 8 <= block_size; j += 8) {
2619
          _mm256_storeu_ps(
2620
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
2621
        }
2622
        for (; j < block_size; j++) {
2623
          op[j] = len_inv * op[j];
2624
        }
2625
      }
2626
    }
2627
  }
2628
  return dataInd == index_size;
2629
}
2630
bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__avx2_fma(
2631
    const int64_t block_size,
2632
    const int64_t output_size,
2633
    const int64_t index_size,
2634
    const int64_t data_size,
2635
    const at::BFloat16* input,
2636
    const int* indices,
2637
    const int* offsets,
2638
    const float* weights,
2639
    const float* scale_bias,
2640
    bool normalize_by_lengths,
2641
    float* out) {
2642
  return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<false>(
2643
      block_size,
2644
      output_size,
2645
      index_size,
2646
      data_size,
2647
      input,
2648
      indices,
2649
      offsets,
2650
      weights,
2651
      scale_bias,
2652
      normalize_by_lengths,
2653
      out);
2654
}
2655
bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__avx2_fma(
2656
    const int64_t block_size,
2657
    const int64_t output_size,
2658
    const int64_t index_size,
2659
    const int64_t data_size,
2660
    const at::BFloat16* input,
2661
    const int* indices,
2662
    const int* offsets,
2663
    const float* weights,
2664
    const float* scale_bias,
2665
    bool normalize_by_lengths,
2666
    float* out) {
2667
  return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<true>(
2668
      block_size,
2669
      output_size,
2670
      index_size,
2671
      data_size,
2672
      input,
2673
      indices,
2674
      offsets,
2675
      weights,
2676
      scale_bias,
2677
      normalize_by_lengths,
2678
      out);
2679
}
2680

2681
template <bool IS_WEIGHT_POSITIONAL>
2682
static bool EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma(
2683
    const int64_t block_size,
2684
    const int64_t output_size,
2685
    const int64_t index_size,
2686
    const int64_t data_size,
2687
    const at::BFloat16* input,
2688
    const int64_t* indices,
2689
    const int64_t* offsets,
2690
    const float* weights,
2691
    const float* scale_bias,
2692
    bool normalize_by_lengths,
2693
    float* out) {
2694
  const int64_t prefdist_T0 = 16;
2695
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2696
  const int64_t fused_block_size = block_size + 0;
2697
  int64_t dataInd = 0;
2698
  if (block_size == 128) {
2699
    // unrolling 16 times
2700
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2701
      float* op = &out[rangeIndex * block_size];
2702
      __m256 vop0 = _mm256_setzero_ps();
2703
      __m256 vop8 = _mm256_setzero_ps();
2704
      __m256 vop16 = _mm256_setzero_ps();
2705
      __m256 vop24 = _mm256_setzero_ps();
2706
      __m256 vop32 = _mm256_setzero_ps();
2707
      __m256 vop40 = _mm256_setzero_ps();
2708
      __m256 vop48 = _mm256_setzero_ps();
2709
      __m256 vop56 = _mm256_setzero_ps();
2710
      __m256 vop64 = _mm256_setzero_ps();
2711
      __m256 vop72 = _mm256_setzero_ps();
2712
      __m256 vop80 = _mm256_setzero_ps();
2713
      __m256 vop88 = _mm256_setzero_ps();
2714
      __m256 vop96 = _mm256_setzero_ps();
2715
      __m256 vop104 = _mm256_setzero_ps();
2716
      __m256 vop112 = _mm256_setzero_ps();
2717
      __m256 vop120 = _mm256_setzero_ps();
2718
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
2719
        return false;
2720
      }
2721
      int64_t end_offset = offsets[rangeIndex + 1];
2722
      int64_t length = end_offset - offsets[rangeIndex];
2723
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2724
           ++dataInd) {
2725
        const int64_t idx = indices[dataInd];
2726
        if (idx < 0 || idx >= data_size) {
2727
          return false;
2728
        }
2729
        float wgt = 1.f;
2730
        if (weights) {
2731
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2732
        }
2733
        __m256 vwgt = _mm256_set1_ps(wgt);
2734
        const at::BFloat16* ip = &input[idx * fused_block_size];
2735
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2736
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2737
            ? (dataInd + prefdist_T0)
2738
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2739
            : dataInd;
2740
        const int64_t idx_pref_T0 = indices[next_T0];
2741
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2742
          return false;
2743
        }
2744
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2745
        vop0 = _mm256_fmadd_ps(
2746
            vwgt,
2747
            _mm256_castsi256_ps(_mm256_slli_epi32(
2748
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2749
                    reinterpret_cast<const __m128i*>(ip + (0)))),
2750
                16)),
2751
            vop0);
2752
        _mm_prefetch(
2753
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2754
        vop8 = _mm256_fmadd_ps(
2755
            vwgt,
2756
            _mm256_castsi256_ps(_mm256_slli_epi32(
2757
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2758
                    reinterpret_cast<const __m128i*>(ip + (8)))),
2759
                16)),
2760
            vop8);
2761
        // skip unnecessary prefetch of (&ip_next_T0[8])
2762
        vop16 = _mm256_fmadd_ps(
2763
            vwgt,
2764
            _mm256_castsi256_ps(_mm256_slli_epi32(
2765
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2766
                    reinterpret_cast<const __m128i*>(ip + (16)))),
2767
                16)),
2768
            vop16);
2769
        // skip unnecessary prefetch of (&ip_next_T0[16])
2770
        vop24 = _mm256_fmadd_ps(
2771
            vwgt,
2772
            _mm256_castsi256_ps(_mm256_slli_epi32(
2773
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2774
                    reinterpret_cast<const __m128i*>(ip + (24)))),
2775
                16)),
2776
            vop24);
2777
        // skip unnecessary prefetch of (&ip_next_T0[24])
2778
        vop32 = _mm256_fmadd_ps(
2779
            vwgt,
2780
            _mm256_castsi256_ps(_mm256_slli_epi32(
2781
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2782
                    reinterpret_cast<const __m128i*>(ip + (32)))),
2783
                16)),
2784
            vop32);
2785
        _mm_prefetch(
2786
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2787
        vop40 = _mm256_fmadd_ps(
2788
            vwgt,
2789
            _mm256_castsi256_ps(_mm256_slli_epi32(
2790
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2791
                    reinterpret_cast<const __m128i*>(ip + (40)))),
2792
                16)),
2793
            vop40);
2794
        // skip unnecessary prefetch of (&ip_next_T0[40])
2795
        vop48 = _mm256_fmadd_ps(
2796
            vwgt,
2797
            _mm256_castsi256_ps(_mm256_slli_epi32(
2798
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2799
                    reinterpret_cast<const __m128i*>(ip + (48)))),
2800
                16)),
2801
            vop48);
2802
        // skip unnecessary prefetch of (&ip_next_T0[48])
2803
        vop56 = _mm256_fmadd_ps(
2804
            vwgt,
2805
            _mm256_castsi256_ps(_mm256_slli_epi32(
2806
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2807
                    reinterpret_cast<const __m128i*>(ip + (56)))),
2808
                16)),
2809
            vop56);
2810
        // skip unnecessary prefetch of (&ip_next_T0[56])
2811
        vop64 = _mm256_fmadd_ps(
2812
            vwgt,
2813
            _mm256_castsi256_ps(_mm256_slli_epi32(
2814
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2815
                    reinterpret_cast<const __m128i*>(ip + (64)))),
2816
                16)),
2817
            vop64);
2818
        _mm_prefetch(
2819
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
2820
        vop72 = _mm256_fmadd_ps(
2821
            vwgt,
2822
            _mm256_castsi256_ps(_mm256_slli_epi32(
2823
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2824
                    reinterpret_cast<const __m128i*>(ip + (72)))),
2825
                16)),
2826
            vop72);
2827
        // skip unnecessary prefetch of (&ip_next_T0[72])
2828
        vop80 = _mm256_fmadd_ps(
2829
            vwgt,
2830
            _mm256_castsi256_ps(_mm256_slli_epi32(
2831
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2832
                    reinterpret_cast<const __m128i*>(ip + (80)))),
2833
                16)),
2834
            vop80);
2835
        // skip unnecessary prefetch of (&ip_next_T0[80])
2836
        vop88 = _mm256_fmadd_ps(
2837
            vwgt,
2838
            _mm256_castsi256_ps(_mm256_slli_epi32(
2839
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2840
                    reinterpret_cast<const __m128i*>(ip + (88)))),
2841
                16)),
2842
            vop88);
2843
        // skip unnecessary prefetch of (&ip_next_T0[88])
2844
        vop96 = _mm256_fmadd_ps(
2845
            vwgt,
2846
            _mm256_castsi256_ps(_mm256_slli_epi32(
2847
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2848
                    reinterpret_cast<const __m128i*>(ip + (96)))),
2849
                16)),
2850
            vop96);
2851
        _mm_prefetch(
2852
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
2853
        vop104 = _mm256_fmadd_ps(
2854
            vwgt,
2855
            _mm256_castsi256_ps(_mm256_slli_epi32(
2856
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2857
                    reinterpret_cast<const __m128i*>(ip + (104)))),
2858
                16)),
2859
            vop104);
2860
        // skip unnecessary prefetch of (&ip_next_T0[104])
2861
        vop112 = _mm256_fmadd_ps(
2862
            vwgt,
2863
            _mm256_castsi256_ps(_mm256_slli_epi32(
2864
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2865
                    reinterpret_cast<const __m128i*>(ip + (112)))),
2866
                16)),
2867
            vop112);
2868
        // skip unnecessary prefetch of (&ip_next_T0[112])
2869
        vop120 = _mm256_fmadd_ps(
2870
            vwgt,
2871
            _mm256_castsi256_ps(_mm256_slli_epi32(
2872
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2873
                    reinterpret_cast<const __m128i*>(ip + (120)))),
2874
                16)),
2875
            vop120);
2876
        // skip unnecessary prefetch of (&ip_next_T0[120])
2877
      }
2878
      if (!normalize_by_lengths || length == 0) {
2879
        _mm256_storeu_ps(&op[0], vop0);
2880
        _mm256_storeu_ps(&op[8], vop8);
2881
        _mm256_storeu_ps(&op[16], vop16);
2882
        _mm256_storeu_ps(&op[24], vop24);
2883
        _mm256_storeu_ps(&op[32], vop32);
2884
        _mm256_storeu_ps(&op[40], vop40);
2885
        _mm256_storeu_ps(&op[48], vop48);
2886
        _mm256_storeu_ps(&op[56], vop56);
2887
        _mm256_storeu_ps(&op[64], vop64);
2888
        _mm256_storeu_ps(&op[72], vop72);
2889
        _mm256_storeu_ps(&op[80], vop80);
2890
        _mm256_storeu_ps(&op[88], vop88);
2891
        _mm256_storeu_ps(&op[96], vop96);
2892
        _mm256_storeu_ps(&op[104], vop104);
2893
        _mm256_storeu_ps(&op[112], vop112);
2894
        _mm256_storeu_ps(&op[120], vop120);
2895
      } else {
2896
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
2897
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
2898
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
2899
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
2900
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
2901
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
2902
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
2903
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
2904
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
2905
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
2906
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
2907
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
2908
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
2909
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
2910
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
2911
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
2912
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
2913
      }
2914
    }
2915
  } else if (block_size == 64) {
2916
    // unrolling 8 times
2917
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
2918
      float* op = &out[rangeIndex * block_size];
2919
      __m256 vop0 = _mm256_setzero_ps();
2920
      __m256 vop8 = _mm256_setzero_ps();
2921
      __m256 vop16 = _mm256_setzero_ps();
2922
      __m256 vop24 = _mm256_setzero_ps();
2923
      __m256 vop32 = _mm256_setzero_ps();
2924
      __m256 vop40 = _mm256_setzero_ps();
2925
      __m256 vop48 = _mm256_setzero_ps();
2926
      __m256 vop56 = _mm256_setzero_ps();
2927
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
2928
        return false;
2929
      }
2930
      int64_t end_offset = offsets[rangeIndex + 1];
2931
      int64_t length = end_offset - offsets[rangeIndex];
2932
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
2933
           ++dataInd) {
2934
        const int64_t idx = indices[dataInd];
2935
        if (idx < 0 || idx >= data_size) {
2936
          return false;
2937
        }
2938
        float wgt = 1.f;
2939
        if (weights) {
2940
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
2941
        }
2942
        __m256 vwgt = _mm256_set1_ps(wgt);
2943
        const at::BFloat16* ip = &input[idx * fused_block_size];
2944
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
2945
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2946
            ? (dataInd + prefdist_T0)
2947
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
2948
            : dataInd;
2949
        const int64_t idx_pref_T0 = indices[next_T0];
2950
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
2951
          return false;
2952
        }
2953
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
2954
        vop0 = _mm256_fmadd_ps(
2955
            vwgt,
2956
            _mm256_castsi256_ps(_mm256_slli_epi32(
2957
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2958
                    reinterpret_cast<const __m128i*>(ip + (0)))),
2959
                16)),
2960
            vop0);
2961
        _mm_prefetch(
2962
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
2963
        vop8 = _mm256_fmadd_ps(
2964
            vwgt,
2965
            _mm256_castsi256_ps(_mm256_slli_epi32(
2966
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2967
                    reinterpret_cast<const __m128i*>(ip + (8)))),
2968
                16)),
2969
            vop8);
2970
        // skip unnecessary prefetch of (&ip_next_T0[8])
2971
        vop16 = _mm256_fmadd_ps(
2972
            vwgt,
2973
            _mm256_castsi256_ps(_mm256_slli_epi32(
2974
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2975
                    reinterpret_cast<const __m128i*>(ip + (16)))),
2976
                16)),
2977
            vop16);
2978
        // skip unnecessary prefetch of (&ip_next_T0[16])
2979
        vop24 = _mm256_fmadd_ps(
2980
            vwgt,
2981
            _mm256_castsi256_ps(_mm256_slli_epi32(
2982
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2983
                    reinterpret_cast<const __m128i*>(ip + (24)))),
2984
                16)),
2985
            vop24);
2986
        // skip unnecessary prefetch of (&ip_next_T0[24])
2987
        vop32 = _mm256_fmadd_ps(
2988
            vwgt,
2989
            _mm256_castsi256_ps(_mm256_slli_epi32(
2990
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
2991
                    reinterpret_cast<const __m128i*>(ip + (32)))),
2992
                16)),
2993
            vop32);
2994
        _mm_prefetch(
2995
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
2996
        vop40 = _mm256_fmadd_ps(
2997
            vwgt,
2998
            _mm256_castsi256_ps(_mm256_slli_epi32(
2999
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3000
                    reinterpret_cast<const __m128i*>(ip + (40)))),
3001
                16)),
3002
            vop40);
3003
        // skip unnecessary prefetch of (&ip_next_T0[40])
3004
        vop48 = _mm256_fmadd_ps(
3005
            vwgt,
3006
            _mm256_castsi256_ps(_mm256_slli_epi32(
3007
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3008
                    reinterpret_cast<const __m128i*>(ip + (48)))),
3009
                16)),
3010
            vop48);
3011
        // skip unnecessary prefetch of (&ip_next_T0[48])
3012
        vop56 = _mm256_fmadd_ps(
3013
            vwgt,
3014
            _mm256_castsi256_ps(_mm256_slli_epi32(
3015
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3016
                    reinterpret_cast<const __m128i*>(ip + (56)))),
3017
                16)),
3018
            vop56);
3019
        // skip unnecessary prefetch of (&ip_next_T0[56])
3020
      }
3021
      if (!normalize_by_lengths || length == 0) {
3022
        _mm256_storeu_ps(&op[0], vop0);
3023
        _mm256_storeu_ps(&op[8], vop8);
3024
        _mm256_storeu_ps(&op[16], vop16);
3025
        _mm256_storeu_ps(&op[24], vop24);
3026
        _mm256_storeu_ps(&op[32], vop32);
3027
        _mm256_storeu_ps(&op[40], vop40);
3028
        _mm256_storeu_ps(&op[48], vop48);
3029
        _mm256_storeu_ps(&op[56], vop56);
3030
      } else {
3031
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3032
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3033
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3034
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3035
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3036
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
3037
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
3038
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
3039
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
3040
      }
3041
    }
3042
  } else if (block_size == 32) {
3043
    // unrolling 4 times
3044
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3045
      float* op = &out[rangeIndex * block_size];
3046
      __m256 vop0 = _mm256_setzero_ps();
3047
      __m256 vop8 = _mm256_setzero_ps();
3048
      __m256 vop16 = _mm256_setzero_ps();
3049
      __m256 vop24 = _mm256_setzero_ps();
3050
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
3051
        return false;
3052
      }
3053
      int64_t end_offset = offsets[rangeIndex + 1];
3054
      int64_t length = end_offset - offsets[rangeIndex];
3055
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3056
           ++dataInd) {
3057
        const int64_t idx = indices[dataInd];
3058
        if (idx < 0 || idx >= data_size) {
3059
          return false;
3060
        }
3061
        float wgt = 1.f;
3062
        if (weights) {
3063
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3064
        }
3065
        __m256 vwgt = _mm256_set1_ps(wgt);
3066
        const at::BFloat16* ip = &input[idx * fused_block_size];
3067
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3068
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3069
            ? (dataInd + prefdist_T0)
3070
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3071
            : dataInd;
3072
        const int64_t idx_pref_T0 = indices[next_T0];
3073
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3074
          return false;
3075
        }
3076
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3077
        vop0 = _mm256_fmadd_ps(
3078
            vwgt,
3079
            _mm256_castsi256_ps(_mm256_slli_epi32(
3080
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3081
                    reinterpret_cast<const __m128i*>(ip + (0)))),
3082
                16)),
3083
            vop0);
3084
        _mm_prefetch(
3085
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3086
        vop8 = _mm256_fmadd_ps(
3087
            vwgt,
3088
            _mm256_castsi256_ps(_mm256_slli_epi32(
3089
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3090
                    reinterpret_cast<const __m128i*>(ip + (8)))),
3091
                16)),
3092
            vop8);
3093
        // skip unnecessary prefetch of (&ip_next_T0[8])
3094
        vop16 = _mm256_fmadd_ps(
3095
            vwgt,
3096
            _mm256_castsi256_ps(_mm256_slli_epi32(
3097
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3098
                    reinterpret_cast<const __m128i*>(ip + (16)))),
3099
                16)),
3100
            vop16);
3101
        // skip unnecessary prefetch of (&ip_next_T0[16])
3102
        vop24 = _mm256_fmadd_ps(
3103
            vwgt,
3104
            _mm256_castsi256_ps(_mm256_slli_epi32(
3105
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3106
                    reinterpret_cast<const __m128i*>(ip + (24)))),
3107
                16)),
3108
            vop24);
3109
        // skip unnecessary prefetch of (&ip_next_T0[24])
3110
      }
3111
      if (!normalize_by_lengths || length == 0) {
3112
        _mm256_storeu_ps(&op[0], vop0);
3113
        _mm256_storeu_ps(&op[8], vop8);
3114
        _mm256_storeu_ps(&op[16], vop16);
3115
        _mm256_storeu_ps(&op[24], vop24);
3116
      } else {
3117
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3118
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3119
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3120
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
3121
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
3122
      }
3123
    }
3124
  } else if (block_size == 16) {
3125
    // unrolling 2 times
3126
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3127
      float* op = &out[rangeIndex * block_size];
3128
      __m256 vop0 = _mm256_setzero_ps();
3129
      __m256 vop8 = _mm256_setzero_ps();
3130
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
3131
        return false;
3132
      }
3133
      int64_t end_offset = offsets[rangeIndex + 1];
3134
      int64_t length = end_offset - offsets[rangeIndex];
3135
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3136
           ++dataInd) {
3137
        const int64_t idx = indices[dataInd];
3138
        if (idx < 0 || idx >= data_size) {
3139
          return false;
3140
        }
3141
        float wgt = 1.f;
3142
        if (weights) {
3143
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3144
        }
3145
        __m256 vwgt = _mm256_set1_ps(wgt);
3146
        const at::BFloat16* ip = &input[idx * fused_block_size];
3147
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3148
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3149
            ? (dataInd + prefdist_T0)
3150
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3151
            : dataInd;
3152
        const int64_t idx_pref_T0 = indices[next_T0];
3153
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3154
          return false;
3155
        }
3156
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3157
        vop0 = _mm256_fmadd_ps(
3158
            vwgt,
3159
            _mm256_castsi256_ps(_mm256_slli_epi32(
3160
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3161
                    reinterpret_cast<const __m128i*>(ip + (0)))),
3162
                16)),
3163
            vop0);
3164
        _mm_prefetch(
3165
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
3166
        vop8 = _mm256_fmadd_ps(
3167
            vwgt,
3168
            _mm256_castsi256_ps(_mm256_slli_epi32(
3169
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
3170
                    reinterpret_cast<const __m128i*>(ip + (8)))),
3171
                16)),
3172
            vop8);
3173
        // skip unnecessary prefetch of (&ip_next_T0[8])
3174
      }
3175
      if (!normalize_by_lengths || length == 0) {
3176
        _mm256_storeu_ps(&op[0], vop0);
3177
        _mm256_storeu_ps(&op[8], vop8);
3178
      } else {
3179
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
3180
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
3181
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
3182
      }
3183
    }
3184
  } else {
3185
    // generic code
3186
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
3187
    alignas(64) at::BFloat16 vtmp1[8] = {0};
3188
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
3189
      float* op = &out[rangeIndex * block_size];
3190
      int64_t j = 0;
3191
      for (; j + 8 <= block_size; j += 8) {
3192
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
3193
      }
3194
      for (; j < block_size; j++) {
3195
        op[j] = 0.0f;
3196
      }
3197
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
3198
        return false;
3199
      }
3200
      int64_t end_offset = offsets[rangeIndex + 1];
3201
      int64_t length = end_offset - offsets[rangeIndex];
3202
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
3203
           ++dataInd) {
3204
        const int64_t idx = indices[dataInd];
3205
        if (idx < 0 || idx >= data_size) {
3206
          return false;
3207
        }
3208
        float wgt = 1.f;
3209
        if (weights) {
3210
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
3211
        }
3212
        __m256 vwgt = _mm256_set1_ps(wgt);
3213
        const at::BFloat16* ip = &input[idx * fused_block_size];
3214
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
3215
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3216
            ? (dataInd + prefdist_T0)
3217
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
3218
            : dataInd;
3219
        const int64_t idx_pref_T0 = indices[next_T0];
3220
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
3221
          return false;
3222
        }
3223
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
3224
        j = 0;
3225
        for (; j + 8 <= block_size; j += 8) {
3226
          _mm256_storeu_ps(
3227
              &op[j],
3228
              _mm256_fmadd_ps(
3229
                  vwgt,
3230
                  _mm256_castsi256_ps(_mm256_slli_epi32(
3231
                      _mm256_cvtepu16_epi32(_mm_loadu_si128(
3232
                          reinterpret_cast<const __m128i*>(&ip[j]))),
3233
                      16)),
3234
                  _mm256_loadu_ps(&op[j])));
3235
          _mm_prefetch(
3236
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
3237
        }
3238
        for (; j < block_size; j++) {
3239
          vtmp1[0] = ip[j];
3240
          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
3241
              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
3242
              16));
3243
          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
3244
        }
3245
      }
3246
      if (normalize_by_lengths && length) {
3247
        float len_inv = 1.0f / length;
3248
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
3249
        j = 0;
3250
        for (; j + 8 <= block_size; j += 8) {
3251
          _mm256_storeu_ps(
3252
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
3253
        }
3254
        for (; j < block_size; j++) {
3255
          op[j] = len_inv * op[j];
3256
        }
3257
      }
3258
    }
3259
  }
3260
  return dataInd == index_size;
3261
}
3262
bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__avx2_fma(
3263
    const int64_t block_size,
3264
    const int64_t output_size,
3265
    const int64_t index_size,
3266
    const int64_t data_size,
3267
    const at::BFloat16* input,
3268
    const int64_t* indices,
3269
    const int64_t* offsets,
3270
    const float* weights,
3271
    const float* scale_bias,
3272
    bool normalize_by_lengths,
3273
    float* out) {
3274
  return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<false>(
3275
      block_size,
3276
      output_size,
3277
      index_size,
3278
      data_size,
3279
      input,
3280
      indices,
3281
      offsets,
3282
      weights,
3283
      scale_bias,
3284
      normalize_by_lengths,
3285
      out);
3286
}
3287
bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__avx2_fma(
3288
    const int64_t block_size,
3289
    const int64_t output_size,
3290
    const int64_t index_size,
3291
    const int64_t data_size,
3292
    const at::BFloat16* input,
3293
    const int64_t* indices,
3294
    const int64_t* offsets,
3295
    const float* weights,
3296
    const float* scale_bias,
3297
    bool normalize_by_lengths,
3298
    float* out) {
3299
  return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<true>(
3300
      block_size,
3301
      output_size,
3302
      index_size,
3303
      data_size,
3304
      input,
3305
      indices,
3306
      offsets,
3307
      weights,
3308
      scale_bias,
3309
      normalize_by_lengths,
3310
      out);
3311
}
3312

3313
// NOTE(review): this file is generated by hp_emblookup_codegen.py; the two
// fixes below (64-bit row-offset arithmetic and the segment-end bounds check)
// should also be made in the generator so they survive regeneration.
//
// Pools rows of a uint8 rowwise-quantized table into `out`, producing one
// output row per segment, using AVX2 + FMA.
//
// Template parameter:
//   IS_WEIGHT_POSITIONAL - if true, weights are indexed by the position inside
//     the current segment (dataInd - start); otherwise by the global position
//     dataInd in `indices`.
//
// Parameters:
//   block_size  - number of values per table row and per output row.
//   output_size - number of segments (output rows).
//   index_size  - number of entries in `indices`.
//   data_size   - number of table rows; every index must be in [0, data_size).
//   input       - uint8 table, block_size bytes per row.
//   indices     - table rows to gather (int32).
//   offsets     - output_size + 1 boundaries; segment r covers
//                 indices[offsets[r] - offsets[0], offsets[r + 1] - offsets[0]).
//   weights     - optional per-lookup weights; nullptr means weight 1.
//   scale_bias  - two floats per table row i: scale_bias[2 * i] multiplies
//                 the row's uint8 values and scale_bias[2 * i + 1] is added.
//   normalize_by_lengths - if true, each non-empty segment's sum is
//                 multiplied by 1 / length.
//   out         - output_size * block_size floats; each processed output row
//                 is overwritten, not accumulated into.
//
// Each lookup adds w * (scale * x + bias), evaluated as
// fma(w * scale, x, acc + w * bias).
//
// Returns false if a segment does not start where the previous one ended, if
// a segment extends past index_size, or if a looked-up or prefetched index is
// outside [0, data_size). Otherwise returns whether exactly index_size
// indices were consumed.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int* indices,
    const int* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // The row looked up this many positions ahead is prefetched.
  const int64_t prefdist_T0 = 16;
  // Table rows hold exactly block_size bytes (scale/bias live in scale_bias).
  // This is int64_t so that idx * fused_block_size is computed in 64 bits:
  // with int, tables larger than INT_MAX bytes overflowed (UB).
  const int64_t fused_block_size = block_size;
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject a segment that runs past `indices` before reading from it.
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        // Fold the lookup weight into the row's bias and scale.
        const float bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // Each step widens 8 bytes to floats; the upcoming row is prefetched
        // once per 64 bytes.
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),
            _mm256_add_ps(vop64, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),
            _mm256_add_ps(vop72, vbio));
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),
            _mm256_add_ps(vop80, vbio));
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),
            _mm256_add_ps(vop88, vbio));
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),
            _mm256_add_ps(vop96, vbio));
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),
            _mm256_add_ps(vop104, vbio));
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),
            _mm256_add_ps(vop112, vbio));
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),
            _mm256_add_ps(vop120, vbio));
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject a segment that runs past `indices` before reading from it.
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        // Fold the lookup weight into the row's bias and scale.
        const float bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // Each step widens 8 bytes to floats; the 64-byte upcoming row is
        // prefetched once.
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),
            _mm256_add_ps(vop32, vbio));
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),
            _mm256_add_ps(vop40, vbio));
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),
            _mm256_add_ps(vop48, vbio));
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),
            _mm256_add_ps(vop56, vbio));
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject a segment that runs past `indices` before reading from it.
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        // Fold the lookup weight into the row's bias and scale.
        const float bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),
            _mm256_add_ps(vop16, vbio));
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),
            _mm256_add_ps(vop24, vbio));
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject a segment that runs past `indices` before reading from it.
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        // Fold the lookup weight into the row's bias and scale.
        const float bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),
            _mm256_add_ps(vop0, vbio));
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
                _mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),
            _mm256_add_ps(vop8, vbio));
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code: any block_size, 8-wide vector body plus scalar tail.
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject a segment that runs past `indices` before reading from it.
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        // Fold the lookup weight into the row's bias and scale.
        const float bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] = std::fma(wgt, static_cast<float>(ip[j]), bio + op[j]);
        }
      }
      if (normalize_by_lengths && length != 0) {
        const float len_inv = 1.0f / length;
        const __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
3849
bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(
3850
    const int64_t block_size,
3851
    const int64_t output_size,
3852
    const int64_t index_size,
3853
    const int64_t data_size,
3854
    const uint8_t* input,
3855
    const int* indices,
3856
    const int* offsets,
3857
    const float* weights,
3858
    const float* scale_bias,
3859
    bool normalize_by_lengths,
3860
    float* out) {
3861
  return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<false>(
3862
      block_size,
3863
      output_size,
3864
      index_size,
3865
      data_size,
3866
      input,
3867
      indices,
3868
      offsets,
3869
      weights,
3870
      scale_bias,
3871
      normalize_by_lengths,
3872
      out);
3873
}
3874
bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(
3875
    const int64_t block_size,
3876
    const int64_t output_size,
3877
    const int64_t index_size,
3878
    const int64_t data_size,
3879
    const uint8_t* input,
3880
    const int* indices,
3881
    const int* offsets,
3882
    const float* weights,
3883
    const float* scale_bias,
3884
    bool normalize_by_lengths,
3885
    float* out) {
3886
  return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<true>(
3887
      block_size,
3888
      output_size,
3889
      index_size,
3890
      data_size,
3891
      input,
3892
      indices,
3893
      offsets,
3894
      weights,
3895
      scale_bias,
3896
      normalize_by_lengths,
3897
      out);
3898
}
3899

3900
// Single accumulation step shared by every path of the kernel below: loads 8
// consecutive uint8 values from `src`, widens them to float and returns
// vwgt * src + (acc + vbio). The operand order is exactly the one the
// generated code used, so results are bit-for-bit unchanged.
static inline __m256 FmaddU8x8_int64_avx2(
    __m256 acc,
    __m256 vwgt,
    __m256 vbio,
    const uint8_t* src) {
  return _mm256_fmadd_ps(
      vwgt,
      _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(
          _mm_loadl_epi64(reinterpret_cast<const __m128i*>(src)))),
      _mm256_add_ps(acc, vbio));
}

// Pooled embedding lookup over uint8 rows with a per-row (scale, bias) pair,
// int64 indices/offsets variant (AVX2 + FMA).
//
// `offsets` holds output_size + 1 entries and is rebased by offsets[0]:
// output row r is built from index positions
//   [offsets[r] - offsets[0], offsets[r + 1] - offsets[0]).
// For each position k in that segment, with row = indices[k] and
// w = 1 if `weights` is null, else weights[k] (or weights[k - segment start]
// when IS_WEIGHT_POSITIONAL):
//   out[r * block_size + j] += w * scale_bias[2 * row] * input[row * block_size + j]
//                            + w * scale_bias[2 * row + 1]
// starting from zero (output rows are overwritten, not accumulated into).
// With normalize_by_lengths, every non-empty output row is then multiplied by
// 1 / segment length.
//
// Returns true only if exactly index_size index positions were consumed.
// Returns false, possibly after some output rows were already written, when:
//   * a segment does not start where the previous one ended,
//   * a segment ends past index_size (previously an out-of-bounds read),
//   * an index, or the index prefdist_T0 positions ahead that is read for
//     prefetching, lies outside [0, data_size).
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Distance, in index positions, between the row being accumulated and the
  // row whose data is prefetched.
  const int64_t prefdist_T0 = 16;
  // Row stride of `input` in bytes: scale/bias live in `scale_bias`, not in
  // the row itself, so a row is exactly block_size bytes.
  const int64_t fused_block_size = block_size;
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Without this check a segment ending past index_size would read
      // indices[] / weights[] out of bounds.
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        const float w = weights
            ? weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]
            : 1.f;
        // Fold the weight into the row's scale and bias.
        const __m256 vbio = _mm256_set1_ps(w * scale_bias[2 * idx + 1]);
        const __m256 vwgt = _mm256_set1_ps(w * scale_bias[2 * idx]);
        const uint8_t* ip = &input[idx * fused_block_size];
        // Near the end of indices[], the current row is "prefetched" instead.
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        // next_T0 < index_size, and success requires consuming every
        // position, so this only rejects inputs that would fail later anyway.
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // Prefetch only at byte offsets 0 and 64 (presumably one per 64-byte
        // cache line).
        vop0 = FmaddU8x8_int64_avx2(vop0, vwgt, vbio, ip + 0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = FmaddU8x8_int64_avx2(vop8, vwgt, vbio, ip + 8);
        vop16 = FmaddU8x8_int64_avx2(vop16, vwgt, vbio, ip + 16);
        vop24 = FmaddU8x8_int64_avx2(vop24, vwgt, vbio, ip + 24);
        vop32 = FmaddU8x8_int64_avx2(vop32, vwgt, vbio, ip + 32);
        vop40 = FmaddU8x8_int64_avx2(vop40, vwgt, vbio, ip + 40);
        vop48 = FmaddU8x8_int64_avx2(vop48, vwgt, vbio, ip + 48);
        vop56 = FmaddU8x8_int64_avx2(vop56, vwgt, vbio, ip + 56);
        vop64 = FmaddU8x8_int64_avx2(vop64, vwgt, vbio, ip + 64);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = FmaddU8x8_int64_avx2(vop72, vwgt, vbio, ip + 72);
        vop80 = FmaddU8x8_int64_avx2(vop80, vwgt, vbio, ip + 80);
        vop88 = FmaddU8x8_int64_avx2(vop88, vwgt, vbio, ip + 88);
        vop96 = FmaddU8x8_int64_avx2(vop96, vwgt, vbio, ip + 96);
        vop104 = FmaddU8x8_int64_avx2(vop104, vwgt, vbio, ip + 104);
        vop112 = FmaddU8x8_int64_avx2(vop112, vwgt, vbio, ip + 112);
        vop120 = FmaddU8x8_int64_avx2(vop120, vwgt, vbio, ip + 120);
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject segments that would read past the end of indices[].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        const float w = weights
            ? weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]
            : 1.f;
        const __m256 vbio = _mm256_set1_ps(w * scale_bias[2 * idx + 1]);
        const __m256 vwgt = _mm256_set1_ps(w * scale_bias[2 * idx]);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = FmaddU8x8_int64_avx2(vop0, vwgt, vbio, ip + 0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = FmaddU8x8_int64_avx2(vop8, vwgt, vbio, ip + 8);
        vop16 = FmaddU8x8_int64_avx2(vop16, vwgt, vbio, ip + 16);
        vop24 = FmaddU8x8_int64_avx2(vop24, vwgt, vbio, ip + 24);
        vop32 = FmaddU8x8_int64_avx2(vop32, vwgt, vbio, ip + 32);
        vop40 = FmaddU8x8_int64_avx2(vop40, vwgt, vbio, ip + 40);
        vop48 = FmaddU8x8_int64_avx2(vop48, vwgt, vbio, ip + 48);
        vop56 = FmaddU8x8_int64_avx2(vop56, vwgt, vbio, ip + 56);
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject segments that would read past the end of indices[].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        const float w = weights
            ? weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]
            : 1.f;
        const __m256 vbio = _mm256_set1_ps(w * scale_bias[2 * idx + 1]);
        const __m256 vwgt = _mm256_set1_ps(w * scale_bias[2 * idx]);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = FmaddU8x8_int64_avx2(vop0, vwgt, vbio, ip + 0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = FmaddU8x8_int64_avx2(vop8, vwgt, vbio, ip + 8);
        vop16 = FmaddU8x8_int64_avx2(vop16, vwgt, vbio, ip + 16);
        vop24 = FmaddU8x8_int64_avx2(vop24, vwgt, vbio, ip + 24);
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject segments that would read past the end of indices[].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        const float w = weights
            ? weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]
            : 1.f;
        const __m256 vbio = _mm256_set1_ps(w * scale_bias[2 * idx + 1]);
        const __m256 vwgt = _mm256_set1_ps(w * scale_bias[2 * idx]);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = FmaddU8x8_int64_avx2(vop0, vwgt, vbio, ip + 0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = FmaddU8x8_int64_avx2(vop8, vwgt, vbio, ip + 8);
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code: accumulate directly in `out`, 8 floats at a time, with a
    // scalar tail for block_size % 8 elements.
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      // Reject segments that would read past the end of indices[].
      if (end_offset - offsets[0] > index_size) {
        return false;
      }
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        const float w = weights
            ? weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd]
            : 1.f;
        const float bio = w * scale_bias[2 * idx + 1];
        const float wgt = w * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              FmaddU8x8_int64_avx2(_mm256_loadu_ps(&op[j]), vwgt, vbio, &ip[j]));
          // Issued every 8 bytes (several per cache line), as generated.
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          op[j] = std::fma(wgt, static_cast<float>(ip[j]), bio + op[j]);
        }
      }
      if (normalize_by_lengths && length) {
        const float len_inv = 1.0f / length;
        const __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  return dataInd == index_size;
}
4436
bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__avx2_fma(
4437
    const int64_t block_size,
4438
    const int64_t output_size,
4439
    const int64_t index_size,
4440
    const int64_t data_size,
4441
    const uint8_t* input,
4442
    const int64_t* indices,
4443
    const int64_t* offsets,
4444
    const float* weights,
4445
    const float* scale_bias,
4446
    bool normalize_by_lengths,
4447
    float* out) {
4448
  return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<false>(
4449
      block_size,
4450
      output_size,
4451
      index_size,
4452
      data_size,
4453
      input,
4454
      indices,
4455
      offsets,
4456
      weights,
4457
      scale_bias,
4458
      normalize_by_lengths,
4459
      out);
4460
}
4461
bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__avx2_fma(
4462
    const int64_t block_size,
4463
    const int64_t output_size,
4464
    const int64_t index_size,
4465
    const int64_t data_size,
4466
    const uint8_t* input,
4467
    const int64_t* indices,
4468
    const int64_t* offsets,
4469
    const float* weights,
4470
    const float* scale_bias,
4471
    bool normalize_by_lengths,
4472
    float* out) {
4473
  return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<true>(
4474
      block_size,
4475
      output_size,
4476
      index_size,
4477
      data_size,
4478
      input,
4479
      indices,
4480
      offsets,
4481
      weights,
4482
      scale_bias,
4483
      normalize_by_lengths,
4484
      out);
4485
}
4486

4487
} // namespace caffe2
4488

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.