pytorch
4487 lines · 181.0 KB
1//// --------------------------
2//// ATTENTION:
3//// THIS CODE IS AUTOGENERATED
4//// BY hp_emblookup_codegen.py
5//// DO NOT MODIFY!!!
6//// --------------------------
7
#include <c10/util/Half.h>
#include <c10/util/BFloat16.h>
#include <immintrin.h>
namespace caffe2 {

// Segment-sum of (optionally weighted) embedding rows, AVX2+FMA kernel.
//
// For each output segment r in [0, output_size), sums the rows of `input`
// selected by indices[offsets[r]-offsets[0] .. offsets[r+1]-offsets[0]) into
// out[r * block_size .. +block_size), each row scaled by its weight when
// `weights` is non-null. Hand-unrolled fast paths exist for block_size
// 128/64/32/16; any other width takes the generic vector-plus-scalar-tail path.
//
//   block_size  - embedding row width (floats per row).
//   output_size - number of output segments.
//   index_size  - total number of entries in `indices`.
//   data_size   - number of rows in `input`; indices are bounds-checked
//                 against it.
//   input       - [data_size x block_size] embedding table.
//   indices     - [index_size] row indices into `input`.
//   offsets     - CSR-style segment boundaries; offsets[r+1] ends segment r.
//   weights     - optional scale per index, or per position within the
//                 segment when IS_WEIGHT_POSITIONAL; may be null.
//   scale_bias  - unused in this float variant (kept for signature parity
//                 with the fused quantized variants).
//   normalize_by_lengths - when true, divide each output row by its
//                 (non-zero) segment length, i.e. produce a mean.
//   out         - [output_size x block_size] output buffer.
//
// Returns false if the offsets are inconsistent with the running index
// cursor or any index (including the prefetch lookahead index) is out of
// range; returns true when exactly index_size indices were consumed.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_float_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int* indices,
    const int* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Software prefetch distance: the row that will be needed 16 indices from
  // now is pulled toward L1 while the current row is being accumulated.
  const int prefdist_T0 = 16;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int fused_block_size = block_size + 0;
  // Running cursor into `indices`; advanced across all segments and compared
  // against offsets[] to validate that the segments tile [0, index_size).
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // 16 independent 8-float accumulators cover the 128-float row.
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      // Offsets must match the cursor exactly (segments are contiguous).
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          // Positional weights index by position within the segment.
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        // Lookahead index for prefetch; clamps to the current index near the
        // end so indices[next_T0] is always in range of the array.
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // Accumulate 8 floats at a time; prefetch one cache line (64 bytes =
        // 16 floats) of the lookahead row per pair of FMAs.
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
        // skip unnecessary prefetch of (&ip_next_T0[72])
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
        // skip unnecessary prefetch of (&ip_next_T0[88])
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
        // skip unnecessary prefetch of (&ip_next_T0[104])
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
        // skip unnecessary prefetch of (&ip_next_T0[120])
      }
      // Empty segments (length == 0) store the zero accumulators as-is,
      // avoiding a division by zero below.
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        // Mean pooling: multiply by the reciprocal length once per segment.
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code: accumulate directly into `out` (no register tiling),
    // vectorizing 8 floats at a time with a scalar tail for block_size % 8.
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // Zero the output row before accumulation.
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          // Scalar fused multiply-add for the tail elements.
          op[j] = std::fma(wgt, ip[j], op[j]);
        }
      }
      if (normalize_by_lengths && length) {
        float len_inv = 1.0f / length;
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  // Success only if every index was consumed exactly once.
  return dataInd == index_size;
}
411bool EmbeddingLookupIdx_int32_t_float_float_false__avx2_fma(412const int64_t block_size,413const int64_t output_size,414const int64_t index_size,415const int64_t data_size,416const float* input,417const int* indices,418const int* offsets,419const float* weights,420const float* scale_bias,421bool normalize_by_lengths,422float* out) {423return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<false>(424block_size,425output_size,426index_size,427data_size,428input,429indices,430offsets,431weights,432scale_bias,433normalize_by_lengths,434out);435}
436bool EmbeddingLookupIdx_int32_t_float_float_true__avx2_fma(437const int64_t block_size,438const int64_t output_size,439const int64_t index_size,440const int64_t data_size,441const float* input,442const int* indices,443const int* offsets,444const float* weights,445const float* scale_bias,446bool normalize_by_lengths,447float* out) {448return EmbeddingLookupIdx_int32_t_float_float__avx2_fma<true>(449block_size,450output_size,451index_size,452data_size,453input,454indices,455offsets,456weights,457scale_bias,458normalize_by_lengths,459out);460}
461
// Segment-sum of (optionally weighted) embedding rows, AVX2+FMA kernel.
// Identical algorithm to the int32_t variant above, but with 64-bit indices
// and offsets. For each output segment r, sums the `input` rows selected by
// that segment's slice of `indices` into out[r * block_size ..], scaled by
// `weights` when non-null. Unrolled fast paths for block_size 128/64/32/16;
// generic vector+scalar-tail path otherwise. `scale_bias` is unused here
// (signature parity with the fused quantized variants). Returns false on
// inconsistent offsets or any out-of-range index (including the prefetch
// lookahead index); true when exactly index_size indices were consumed.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_float_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const float* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Software prefetch distance (in indices) for pulling upcoming rows to L1.
  const int64_t prefdist_T0 = 16;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int64_t fused_block_size = block_size + 0;
  // Running cursor into `indices`, validated against offsets[] per segment.
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // 16 independent 8-float accumulators cover the 128-float row.
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          // Positional weights index by position within the segment.
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        // Lookahead index for prefetch; clamps near the end of `indices`.
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // One prefetch per 64-byte cache line (16 floats) of the lookahead row.
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (64)), vop64);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (72)), vop72);
        // skip unnecessary prefetch of (&ip_next_T0[72])
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (80)), vop80);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[80]), _MM_HINT_T0);
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (88)), vop88);
        // skip unnecessary prefetch of (&ip_next_T0[88])
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (96)), vop96);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (104)), vop104);
        // skip unnecessary prefetch of (&ip_next_T0[104])
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (112)), vop112);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[112]), _MM_HINT_T0);
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (120)), vop120);
        // skip unnecessary prefetch of (&ip_next_T0[120])
      }
      // Empty segments store the zero accumulators, avoiding divide-by-zero.
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        // Mean pooling: scale by reciprocal segment length.
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (32)), vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (40)), vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (48)), vop48);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[48]), _MM_HINT_T0);
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (56)), vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (16)), vop16);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[16]), _MM_HINT_T0);
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (24)), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (0)), vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_loadu_ps(ip + (8)), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code: accumulate directly into `out` (no register tiling),
    // vectorizing 8 floats at a time with a scalar tail for block_size % 8.
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // Zero the output row before accumulation.
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const float* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const float* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt, _mm256_loadu_ps(&ip[j]), _mm256_loadu_ps(&op[j])));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        for (; j < block_size; j++) {
          // Scalar fused multiply-add for the tail elements.
          op[j] = std::fma(wgt, ip[j], op[j]);
        }
      }
      if (normalize_by_lengths && length) {
        float len_inv = 1.0f / length;
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  // Success only if every index was consumed exactly once.
  return dataInd == index_size;
}
860bool EmbeddingLookupIdx_int64_t_float_float_false__avx2_fma(861const int64_t block_size,862const int64_t output_size,863const int64_t index_size,864const int64_t data_size,865const float* input,866const int64_t* indices,867const int64_t* offsets,868const float* weights,869const float* scale_bias,870bool normalize_by_lengths,871float* out) {872return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<false>(873block_size,874output_size,875index_size,876data_size,877input,878indices,879offsets,880weights,881scale_bias,882normalize_by_lengths,883out);884}
885bool EmbeddingLookupIdx_int64_t_float_float_true__avx2_fma(886const int64_t block_size,887const int64_t output_size,888const int64_t index_size,889const int64_t data_size,890const float* input,891const int64_t* indices,892const int64_t* offsets,893const float* weights,894const float* scale_bias,895bool normalize_by_lengths,896float* out) {897return EmbeddingLookupIdx_int64_t_float_float__avx2_fma<true>(898block_size,899output_size,900index_size,901data_size,902input,903indices,904offsets,905weights,906scale_bias,907normalize_by_lengths,908out);909}
910
// Embedding-bag style lookup kernel (AVX2 + FMA + F16C), fp16 table rows,
// fp32 output, int32 indices/offsets.
//
// For each of `output_size` ranges delimited by consecutive entries of
// `offsets`, accumulates the rows input[indices[i] * block_size]
// (converted fp16 -> fp32), each optionally scaled by a weight:
// weights[i] normally, or weights[i - range_start] when IS_WEIGHT_POSITIONAL.
// When `normalize_by_lengths` is set and the range is non-empty, the
// accumulated row is divided by the range length.  `scale_bias` is unused
// here; it is kept for signature parity with the fused 8-bit variants.
//
// Returns false on inconsistent input: any index outside [0, data_size), or
// offsets that do not walk [0, index_size) exactly (final check
// `dataInd == index_size`).  Rows needed `prefdist_T0` iterations ahead are
// prefetched.  Fully unrolled fast paths exist for block_size 128/64/32/16;
// other sizes take the generic loop at the bottom.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_half_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::Half* input,
    const int* indices,
    const int* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  const int prefdist_T0 = 16;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int fused_block_size = block_size + 0;
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // One fp32 accumulator per 8-lane group of the 128-wide row.
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      // dataInd must equal the running offset, else the offsets are invalid.
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::Half* ip = &input[idx * fused_block_size];
        // Row to prefetch: prefdist_T0 iterations ahead (clamped at the end).
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // Widen each fp16 group to fp32 and fused-multiply-accumulate.
        // Prefetches are issued once per 64-byte cache line (32 halfs).
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), vop0);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), vop16);
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), vop32);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), vop48);
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
        vop64 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))), vop64);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))), vop72);
        // skip unnecessary prefetch of (&ip_next_T0[72])
        vop80 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))), vop80);
        // skip unnecessary prefetch of (&ip_next_T0[80])
        vop88 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))), vop88);
        // skip unnecessary prefetch of (&ip_next_T0[88])
        vop96 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))), vop96);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))), vop104);
        // skip unnecessary prefetch of (&ip_next_T0[104])
        vop112 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))), vop112);
        // skip unnecessary prefetch of (&ip_next_T0[112])
        vop120 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))), vop120);
        // skip unnecessary prefetch of (&ip_next_T0[120])
      }
      // Flush accumulators; divide by range length when normalizing.
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::Half* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), vop0);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), vop16);
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))), vop32);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))), vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))), vop48);
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))), vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::Half* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), vop0);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))), vop16);
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))), vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::Half* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))), vop0);
        _mm_prefetch(reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(vwgt, _mm256_cvtph_ps(_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))), vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    // Scratch used to widen a single fp16 tail element via the vector
    // cvtph instruction (no scalar F16C conversion is used here).
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
    alignas(64) at::Half vtmp1[8] = {0};
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // Zero the output row (vector body + scalar tail).
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::Half* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // Accumulate in place: op[j..j+7] += wgt * fp32(ip[j..j+7]).
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtph_ps(_mm_loadu_si128(
                      reinterpret_cast<const __m128i*>(&ip[j]))),
                  _mm256_loadu_ps(&op[j])));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        // Scalar tail: widen one half at a time through vtmp1/vtmp2.
        for (; j < block_size; j++) {
          vtmp1[0] = ip[j];
          __m256 vtmp2 =
              _mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));
          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
        }
      }
      if (normalize_by_lengths && length) {
        float len_inv = 1.0f / length;
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  // True only when offsets consumed exactly index_size indices.
  return dataInd == index_size;
}
1429bool EmbeddingLookupIdx_int32_t_half_float_false__avx2_fma(1430const int64_t block_size,1431const int64_t output_size,1432const int64_t index_size,1433const int64_t data_size,1434const at::Half* input,1435const int* indices,1436const int* offsets,1437const float* weights,1438const float* scale_bias,1439bool normalize_by_lengths,1440float* out) {1441return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<false>(1442block_size,1443output_size,1444index_size,1445data_size,1446input,1447indices,1448offsets,1449weights,1450scale_bias,1451normalize_by_lengths,1452out);1453}
1454bool EmbeddingLookupIdx_int32_t_half_float_true__avx2_fma(1455const int64_t block_size,1456const int64_t output_size,1457const int64_t index_size,1458const int64_t data_size,1459const at::Half* input,1460const int* indices,1461const int* offsets,1462const float* weights,1463const float* scale_bias,1464bool normalize_by_lengths,1465float* out) {1466return EmbeddingLookupIdx_int32_t_half_float__avx2_fma<true>(1467block_size,1468output_size,1469index_size,1470data_size,1471input,1472indices,1473offsets,1474weights,1475scale_bias,1476normalize_by_lengths,1477out);1478}
1479
1480template <bool IS_WEIGHT_POSITIONAL>1481static bool EmbeddingLookupIdx_int64_t_half_float__avx2_fma(1482const int64_t block_size,1483const int64_t output_size,1484const int64_t index_size,1485const int64_t data_size,1486const at::Half* input,1487const int64_t* indices,1488const int64_t* offsets,1489const float* weights,1490const float* scale_bias,1491bool normalize_by_lengths,1492float* out) {1493const int64_t prefdist_T0 = 16;1494// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1495const int64_t fused_block_size = block_size + 0;1496int64_t dataInd = 0;1497if (block_size == 128) {1498// unrolling 16 times1499for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {1500float* op = &out[rangeIndex * block_size];1501__m256 vop0 = _mm256_setzero_ps();1502__m256 vop8 = _mm256_setzero_ps();1503__m256 vop16 = _mm256_setzero_ps();1504__m256 vop24 = _mm256_setzero_ps();1505__m256 vop32 = _mm256_setzero_ps();1506__m256 vop40 = _mm256_setzero_ps();1507__m256 vop48 = _mm256_setzero_ps();1508__m256 vop56 = _mm256_setzero_ps();1509__m256 vop64 = _mm256_setzero_ps();1510__m256 vop72 = _mm256_setzero_ps();1511__m256 vop80 = _mm256_setzero_ps();1512__m256 vop88 = _mm256_setzero_ps();1513__m256 vop96 = _mm256_setzero_ps();1514__m256 vop104 = _mm256_setzero_ps();1515__m256 vop112 = _mm256_setzero_ps();1516__m256 vop120 = _mm256_setzero_ps();1517if (dataInd != offsets[rangeIndex] - offsets[0]) {1518return false;1519}1520int64_t end_offset = offsets[rangeIndex + 1];1521int64_t length = end_offset - offsets[rangeIndex];1522for (int64_t start = dataInd; dataInd < end_offset - offsets[0];1523++dataInd) {1524const int64_t idx = indices[dataInd];1525if (idx < 0 || idx >= data_size) {1526return false;1527}1528float wgt = 1.f;1529if (weights) {1530wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd];1531}1532__m256 vwgt = _mm256_set1_ps(wgt);1533const at::Half* ip = &input[idx * fused_block_size];1534const int64_t next_T0 = (dataInd < index_size - prefdist_T0)1535// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1536? (dataInd + prefdist_T0)1537// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1538: dataInd;1539const int64_t idx_pref_T0 = indices[next_T0];1540if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {1541return false;1542}1543const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];1544vop0 = _mm256_fmadd_ps(1545vwgt,1546_mm256_cvtph_ps(1547_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),1548vop0);1549_mm_prefetch(1550reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);1551vop8 = _mm256_fmadd_ps(1552vwgt,1553_mm256_cvtph_ps(1554_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),1555vop8);1556// skip unnecessary prefetch of (&ip_next_T0[8])1557vop16 = _mm256_fmadd_ps(1558vwgt,1559_mm256_cvtph_ps(1560_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),1561vop16);1562// skip unnecessary prefetch of (&ip_next_T0[16])1563vop24 = _mm256_fmadd_ps(1564vwgt,1565_mm256_cvtph_ps(1566_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),1567vop24);1568// skip unnecessary prefetch of (&ip_next_T0[24])1569vop32 = _mm256_fmadd_ps(1570vwgt,1571_mm256_cvtph_ps(1572_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),1573vop32);1574_mm_prefetch(1575reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);1576vop40 = _mm256_fmadd_ps(1577vwgt,1578_mm256_cvtph_ps(1579_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),1580vop40);1581// skip unnecessary prefetch of (&ip_next_T0[40])1582vop48 = _mm256_fmadd_ps(1583vwgt,1584_mm256_cvtph_ps(1585_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),1586vop48);1587// skip unnecessary prefetch of (&ip_next_T0[48])1588vop56 = 
_mm256_fmadd_ps(1589vwgt,1590_mm256_cvtph_ps(1591_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),1592vop56);1593// skip unnecessary prefetch of (&ip_next_T0[56])1594vop64 = _mm256_fmadd_ps(1595vwgt,1596_mm256_cvtph_ps(1597_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (64)))),1598vop64);1599_mm_prefetch(1600reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);1601vop72 = _mm256_fmadd_ps(1602vwgt,1603_mm256_cvtph_ps(1604_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (72)))),1605vop72);1606// skip unnecessary prefetch of (&ip_next_T0[72])1607vop80 = _mm256_fmadd_ps(1608vwgt,1609_mm256_cvtph_ps(1610_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (80)))),1611vop80);1612// skip unnecessary prefetch of (&ip_next_T0[80])1613vop88 = _mm256_fmadd_ps(1614vwgt,1615_mm256_cvtph_ps(1616_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (88)))),1617vop88);1618// skip unnecessary prefetch of (&ip_next_T0[88])1619vop96 = _mm256_fmadd_ps(1620vwgt,1621_mm256_cvtph_ps(1622_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (96)))),1623vop96);1624_mm_prefetch(1625reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);1626vop104 = _mm256_fmadd_ps(1627vwgt,1628_mm256_cvtph_ps(1629_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (104)))),1630vop104);1631// skip unnecessary prefetch of (&ip_next_T0[104])1632vop112 = _mm256_fmadd_ps(1633vwgt,1634_mm256_cvtph_ps(1635_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (112)))),1636vop112);1637// skip unnecessary prefetch of (&ip_next_T0[112])1638vop120 = _mm256_fmadd_ps(1639vwgt,1640_mm256_cvtph_ps(1641_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (120)))),1642vop120);1643// skip unnecessary prefetch of (&ip_next_T0[120])1644}1645if (!normalize_by_lengths || length == 0) {1646_mm256_storeu_ps(&op[0], vop0);1647_mm256_storeu_ps(&op[8], vop8);1648_mm256_storeu_ps(&op[16], vop16);1649_mm256_storeu_ps(&op[24], vop24);1650_mm256_storeu_ps(&op[32], 
vop32);1651_mm256_storeu_ps(&op[40], vop40);1652_mm256_storeu_ps(&op[48], vop48);1653_mm256_storeu_ps(&op[56], vop56);1654_mm256_storeu_ps(&op[64], vop64);1655_mm256_storeu_ps(&op[72], vop72);1656_mm256_storeu_ps(&op[80], vop80);1657_mm256_storeu_ps(&op[88], vop88);1658_mm256_storeu_ps(&op[96], vop96);1659_mm256_storeu_ps(&op[104], vop104);1660_mm256_storeu_ps(&op[112], vop112);1661_mm256_storeu_ps(&op[120], vop120);1662} else {1663__m256 vlen_inv = _mm256_set1_ps(1.0f / length);1664_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));1665_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));1666_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));1667_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));1668_mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));1669_mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));1670_mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));1671_mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));1672_mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));1673_mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));1674_mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));1675_mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));1676_mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));1677_mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));1678_mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));1679_mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));1680}1681}1682} else if (block_size == 64) {1683// unrolling 8 times1684for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {1685float* op = &out[rangeIndex * block_size];1686__m256 vop0 = _mm256_setzero_ps();1687__m256 vop8 = _mm256_setzero_ps();1688__m256 vop16 = _mm256_setzero_ps();1689__m256 vop24 = _mm256_setzero_ps();1690__m256 vop32 = _mm256_setzero_ps();1691__m256 vop40 = _mm256_setzero_ps();1692__m256 vop48 = _mm256_setzero_ps();1693__m256 vop56 = _mm256_setzero_ps();1694if (dataInd != 
offsets[rangeIndex] - offsets[0]) {1695return false;1696}1697int64_t end_offset = offsets[rangeIndex + 1];1698int64_t length = end_offset - offsets[rangeIndex];1699for (int64_t start = dataInd; dataInd < end_offset - offsets[0];1700++dataInd) {1701const int64_t idx = indices[dataInd];1702if (idx < 0 || idx >= data_size) {1703return false;1704}1705float wgt = 1.f;1706if (weights) {1707wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];1708}1709__m256 vwgt = _mm256_set1_ps(wgt);1710const at::Half* ip = &input[idx * fused_block_size];1711const int64_t next_T0 = (dataInd < index_size - prefdist_T0)1712// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1713? (dataInd + prefdist_T0)1714// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1715: dataInd;1716const int64_t idx_pref_T0 = indices[next_T0];1717if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {1718return false;1719}1720const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];1721vop0 = _mm256_fmadd_ps(1722vwgt,1723_mm256_cvtph_ps(1724_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),1725vop0);1726_mm_prefetch(1727reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);1728vop8 = _mm256_fmadd_ps(1729vwgt,1730_mm256_cvtph_ps(1731_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),1732vop8);1733// skip unnecessary prefetch of (&ip_next_T0[8])1734vop16 = _mm256_fmadd_ps(1735vwgt,1736_mm256_cvtph_ps(1737_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),1738vop16);1739// skip unnecessary prefetch of (&ip_next_T0[16])1740vop24 = _mm256_fmadd_ps(1741vwgt,1742_mm256_cvtph_ps(1743_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),1744vop24);1745// skip unnecessary prefetch of (&ip_next_T0[24])1746vop32 = _mm256_fmadd_ps(1747vwgt,1748_mm256_cvtph_ps(1749_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (32)))),1750vop32);1751_mm_prefetch(1752reinterpret_cast<const 
char*>(&ip_next_T0[32]), _MM_HINT_T0);1753vop40 = _mm256_fmadd_ps(1754vwgt,1755_mm256_cvtph_ps(1756_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (40)))),1757vop40);1758// skip unnecessary prefetch of (&ip_next_T0[40])1759vop48 = _mm256_fmadd_ps(1760vwgt,1761_mm256_cvtph_ps(1762_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (48)))),1763vop48);1764// skip unnecessary prefetch of (&ip_next_T0[48])1765vop56 = _mm256_fmadd_ps(1766vwgt,1767_mm256_cvtph_ps(1768_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (56)))),1769vop56);1770// skip unnecessary prefetch of (&ip_next_T0[56])1771}1772if (!normalize_by_lengths || length == 0) {1773_mm256_storeu_ps(&op[0], vop0);1774_mm256_storeu_ps(&op[8], vop8);1775_mm256_storeu_ps(&op[16], vop16);1776_mm256_storeu_ps(&op[24], vop24);1777_mm256_storeu_ps(&op[32], vop32);1778_mm256_storeu_ps(&op[40], vop40);1779_mm256_storeu_ps(&op[48], vop48);1780_mm256_storeu_ps(&op[56], vop56);1781} else {1782__m256 vlen_inv = _mm256_set1_ps(1.0f / length);1783_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));1784_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));1785_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));1786_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));1787_mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));1788_mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));1789_mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));1790_mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));1791}1792}1793} else if (block_size == 32) {1794// unrolling 4 times1795for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {1796float* op = &out[rangeIndex * block_size];1797__m256 vop0 = _mm256_setzero_ps();1798__m256 vop8 = _mm256_setzero_ps();1799__m256 vop16 = _mm256_setzero_ps();1800__m256 vop24 = _mm256_setzero_ps();1801if (dataInd != offsets[rangeIndex] - offsets[0]) {1802return false;1803}1804int64_t end_offset = offsets[rangeIndex + 1];1805int64_t length = end_offset - 
offsets[rangeIndex];1806for (int64_t start = dataInd; dataInd < end_offset - offsets[0];1807++dataInd) {1808const int64_t idx = indices[dataInd];1809if (idx < 0 || idx >= data_size) {1810return false;1811}1812float wgt = 1.f;1813if (weights) {1814wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];1815}1816__m256 vwgt = _mm256_set1_ps(wgt);1817const at::Half* ip = &input[idx * fused_block_size];1818const int64_t next_T0 = (dataInd < index_size - prefdist_T0)1819// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1820? (dataInd + prefdist_T0)1821// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1822: dataInd;1823const int64_t idx_pref_T0 = indices[next_T0];1824if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {1825return false;1826}1827const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];1828vop0 = _mm256_fmadd_ps(1829vwgt,1830_mm256_cvtph_ps(1831_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),1832vop0);1833_mm_prefetch(1834reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);1835vop8 = _mm256_fmadd_ps(1836vwgt,1837_mm256_cvtph_ps(1838_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),1839vop8);1840// skip unnecessary prefetch of (&ip_next_T0[8])1841vop16 = _mm256_fmadd_ps(1842vwgt,1843_mm256_cvtph_ps(1844_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (16)))),1845vop16);1846// skip unnecessary prefetch of (&ip_next_T0[16])1847vop24 = _mm256_fmadd_ps(1848vwgt,1849_mm256_cvtph_ps(1850_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (24)))),1851vop24);1852// skip unnecessary prefetch of (&ip_next_T0[24])1853}1854if (!normalize_by_lengths || length == 0) {1855_mm256_storeu_ps(&op[0], vop0);1856_mm256_storeu_ps(&op[8], vop8);1857_mm256_storeu_ps(&op[16], vop16);1858_mm256_storeu_ps(&op[24], vop24);1859} else {1860__m256 vlen_inv = _mm256_set1_ps(1.0f / length);1861_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, 
vlen_inv));1862_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));1863_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));1864_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));1865}1866}1867} else if (block_size == 16) {1868// unrolling 2 times1869for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {1870float* op = &out[rangeIndex * block_size];1871__m256 vop0 = _mm256_setzero_ps();1872__m256 vop8 = _mm256_setzero_ps();1873if (dataInd != offsets[rangeIndex] - offsets[0]) {1874return false;1875}1876int64_t end_offset = offsets[rangeIndex + 1];1877int64_t length = end_offset - offsets[rangeIndex];1878for (int64_t start = dataInd; dataInd < end_offset - offsets[0];1879++dataInd) {1880const int64_t idx = indices[dataInd];1881if (idx < 0 || idx >= data_size) {1882return false;1883}1884float wgt = 1.f;1885if (weights) {1886wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];1887}1888__m256 vwgt = _mm256_set1_ps(wgt);1889const at::Half* ip = &input[idx * fused_block_size];1890const int64_t next_T0 = (dataInd < index_size - prefdist_T0)1891// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1892? 
(dataInd + prefdist_T0)1893// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1894: dataInd;1895const int64_t idx_pref_T0 = indices[next_T0];1896if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {1897return false;1898}1899const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];1900vop0 = _mm256_fmadd_ps(1901vwgt,1902_mm256_cvtph_ps(1903_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (0)))),1904vop0);1905_mm_prefetch(1906reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);1907vop8 = _mm256_fmadd_ps(1908vwgt,1909_mm256_cvtph_ps(1910_mm_loadu_si128(reinterpret_cast<const __m128i*>(ip + (8)))),1911vop8);1912// skip unnecessary prefetch of (&ip_next_T0[8])1913}1914if (!normalize_by_lengths || length == 0) {1915_mm256_storeu_ps(&op[0], vop0);1916_mm256_storeu_ps(&op[8], vop8);1917} else {1918__m256 vlen_inv = _mm256_set1_ps(1.0f / length);1919_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));1920_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));1921}1922}1923} else {1924// generic code1925// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)1926alignas(64) at::Half vtmp1[8] = {0};1927for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {1928float* op = &out[rangeIndex * block_size];1929int64_t j = 0;1930for (; j + 8 <= block_size; j += 8) {1931_mm256_storeu_ps(op + j, _mm256_setzero_ps());1932}1933for (; j < block_size; j++) {1934op[j] = 0.0f;1935}1936if (dataInd != offsets[rangeIndex] - offsets[0]) {1937return false;1938}1939int64_t end_offset = offsets[rangeIndex + 1];1940int64_t length = end_offset - offsets[rangeIndex];1941for (int64_t start = dataInd; dataInd < end_offset - offsets[0];1942++dataInd) {1943const int64_t idx = indices[dataInd];1944if (idx < 0 || idx >= data_size) {1945return false;1946}1947float wgt = 1.f;1948if (weights) {1949wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd];1950}1951__m256 vwgt = _mm256_set1_ps(wgt);1952const at::Half* ip = &input[idx * fused_block_size];1953const int64_t next_T0 = (dataInd < index_size - prefdist_T0)1954// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1955? (dataInd + prefdist_T0)1956// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)1957: dataInd;1958const int64_t idx_pref_T0 = indices[next_T0];1959if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {1960return false;1961}1962const at::Half* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];1963j = 0;1964for (; j + 8 <= block_size; j += 8) {1965_mm256_storeu_ps(1966&op[j],1967_mm256_fmadd_ps(1968vwgt,1969_mm256_cvtph_ps(_mm_loadu_si128(1970reinterpret_cast<const __m128i*>(&ip[j]))),1971_mm256_loadu_ps(&op[j])));1972_mm_prefetch(1973reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);1974}1975for (; j < block_size; j++) {1976vtmp1[0] = ip[j];1977__m256 vtmp2 =1978_mm256_cvtph_ps(*(reinterpret_cast<const __m128i*>(vtmp1)));1979op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);1980}1981}1982if (normalize_by_lengths && length) {1983float len_inv = 1.0f / length;1984__m256 vlen_inv = _mm256_set1_ps(len_inv);1985j = 0;1986for (; j + 8 <= block_size; j += 8) {1987_mm256_storeu_ps(1988&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));1989}1990for (; j < block_size; j++) {1991op[j] = len_inv * op[j];1992}1993}1994}1995}1996return dataInd == index_size;1997}
1998bool EmbeddingLookupIdx_int64_t_half_float_false__avx2_fma(1999const int64_t block_size,2000const int64_t output_size,2001const int64_t index_size,2002const int64_t data_size,2003const at::Half* input,2004const int64_t* indices,2005const int64_t* offsets,2006const float* weights,2007const float* scale_bias,2008bool normalize_by_lengths,2009float* out) {2010return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<false>(2011block_size,2012output_size,2013index_size,2014data_size,2015input,2016indices,2017offsets,2018weights,2019scale_bias,2020normalize_by_lengths,2021out);2022}
2023bool EmbeddingLookupIdx_int64_t_half_float_true__avx2_fma(2024const int64_t block_size,2025const int64_t output_size,2026const int64_t index_size,2027const int64_t data_size,2028const at::Half* input,2029const int64_t* indices,2030const int64_t* offsets,2031const float* weights,2032const float* scale_bias,2033bool normalize_by_lengths,2034float* out) {2035return EmbeddingLookupIdx_int64_t_half_float__avx2_fma<true>(2036block_size,2037output_size,2038index_size,2039data_size,2040input,2041indices,2042offsets,2043weights,2044scale_bias,2045normalize_by_lengths,2046out);2047}
2048
// Weighted sum of bfloat16 embedding rows into float outputs using AVX2 FMA.
//
// For each output segment `rangeIndex` (segments are described CSR-style by
// `offsets`, rebased by `offsets[0]`), accumulates
//   out[rangeIndex] += wgt * input[indices[i]]  for i in the segment,
// where `wgt` comes from `weights` (indexed by position within the segment
// when IS_WEIGHT_POSITIONAL, by absolute index otherwise) or defaults to 1.
// If `normalize_by_lengths` and the segment is non-empty, the sum is scaled
// by 1/length. `scale_bias` is accepted but never referenced in this variant.
//
// bf16 -> fp32 conversion: each 16-bit value is zero-extended to 32 bits and
// shifted left by 16, i.e. placed in the upper half of the fp32 bit pattern.
//
// The row for the index `prefdist_T0` positions ahead is prefetched into L1.
// Specialized, unrolled paths exist for block_size 128/64/32/16; anything
// else takes the generic vector+scalar-tail path.
//
// Returns false on any malformed input (offset mismatch, or an index out of
// [0, data_size)); returns true iff exactly `index_size` indices were used.
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const at::BFloat16* input,
    const int* indices,
    const int* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // How many indices ahead of the current one to prefetch.
  const int prefdist_T0 = 16;
  // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
  const int fused_block_size = block_size + 0;
  // Running position in `indices`; must track the (rebased) offsets exactly.
  int64_t dataInd = 0;
  if (block_size == 128) {
    // unrolling 16 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // One fp32 accumulator per 8-lane slice of the 128-wide row.
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      __m256 vop64 = _mm256_setzero_ps();
      __m256 vop72 = _mm256_setzero_ps();
      __m256 vop80 = _mm256_setzero_ps();
      __m256 vop88 = _mm256_setzero_ps();
      __m256 vop96 = _mm256_setzero_ps();
      __m256 vop104 = _mm256_setzero_ps();
      __m256 vop112 = _mm256_setzero_ps();
      __m256 vop120 = _mm256_setzero_ps();
      // Offsets must be contiguous: segment start must equal where the
      // previous segment ended.
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::BFloat16* ip = &input[idx * fused_block_size];
        // Index to prefetch: prefdist_T0 ahead, clamped to the current one.
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // bf16 -> fp32 per slice: zero-extend 16-bit lanes, shift into the
        // upper half of each 32-bit lane, then fused multiply-add.
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (0)))),
                16)),
            vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (8)))),
                16)),
            vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (16)))),
                16)),
            vop16);
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (24)))),
                16)),
            vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (32)))),
                16)),
            vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (40)))),
                16)),
            vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (48)))),
                16)),
            vop48);
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (56)))),
                16)),
            vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
        vop64 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (64)))),
                16)),
            vop64);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);
        vop72 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (72)))),
                16)),
            vop72);
        // skip unnecessary prefetch of (&ip_next_T0[72])
        vop80 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (80)))),
                16)),
            vop80);
        // skip unnecessary prefetch of (&ip_next_T0[80])
        vop88 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (88)))),
                16)),
            vop88);
        // skip unnecessary prefetch of (&ip_next_T0[88])
        vop96 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (96)))),
                16)),
            vop96);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[96]), _MM_HINT_T0);
        vop104 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (104)))),
                16)),
            vop104);
        // skip unnecessary prefetch of (&ip_next_T0[104])
        vop112 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (112)))),
                16)),
            vop112);
        // skip unnecessary prefetch of (&ip_next_T0[112])
        vop120 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (120)))),
                16)),
            vop120);
        // skip unnecessary prefetch of (&ip_next_T0[120])
      }
      // Store accumulators, optionally scaled by 1/length.
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
        _mm256_storeu_ps(&op[64], vop64);
        _mm256_storeu_ps(&op[72], vop72);
        _mm256_storeu_ps(&op[80], vop80);
        _mm256_storeu_ps(&op[88], vop88);
        _mm256_storeu_ps(&op[96], vop96);
        _mm256_storeu_ps(&op[104], vop104);
        _mm256_storeu_ps(&op[112], vop112);
        _mm256_storeu_ps(&op[120], vop120);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
        _mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));
        _mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));
        _mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));
        _mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));
        _mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));
        _mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));
        _mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));
        _mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));
      }
    }
  } else if (block_size == 64) {
    // unrolling 8 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      __m256 vop32 = _mm256_setzero_ps();
      __m256 vop40 = _mm256_setzero_ps();
      __m256 vop48 = _mm256_setzero_ps();
      __m256 vop56 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::BFloat16* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (0)))),
                16)),
            vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (8)))),
                16)),
            vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (16)))),
                16)),
            vop16);
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (24)))),
                16)),
            vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
        vop32 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (32)))),
                16)),
            vop32);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);
        vop40 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (40)))),
                16)),
            vop40);
        // skip unnecessary prefetch of (&ip_next_T0[40])
        vop48 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (48)))),
                16)),
            vop48);
        // skip unnecessary prefetch of (&ip_next_T0[48])
        vop56 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (56)))),
                16)),
            vop56);
        // skip unnecessary prefetch of (&ip_next_T0[56])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
        _mm256_storeu_ps(&op[32], vop32);
        _mm256_storeu_ps(&op[40], vop40);
        _mm256_storeu_ps(&op[48], vop48);
        _mm256_storeu_ps(&op[56], vop56);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
        _mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));
        _mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));
        _mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));
        _mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));
      }
    }
  } else if (block_size == 32) {
    // unrolling 4 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      __m256 vop16 = _mm256_setzero_ps();
      __m256 vop24 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::BFloat16* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (0)))),
                16)),
            vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (8)))),
                16)),
            vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
        vop16 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (16)))),
                16)),
            vop16);
        // skip unnecessary prefetch of (&ip_next_T0[16])
        vop24 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (24)))),
                16)),
            vop24);
        // skip unnecessary prefetch of (&ip_next_T0[24])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
        _mm256_storeu_ps(&op[16], vop16);
        _mm256_storeu_ps(&op[24], vop24);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
        _mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));
        _mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));
      }
    }
  } else if (block_size == 16) {
    // unrolling 2 times
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop0 = _mm256_setzero_ps();
      __m256 vop8 = _mm256_setzero_ps();
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::BFloat16* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        vop0 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (0)))),
                16)),
            vop0);
        _mm_prefetch(
            reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);
        vop8 = _mm256_fmadd_ps(
            vwgt,
            _mm256_castsi256_ps(_mm256_slli_epi32(
                _mm256_cvtepu16_epi32(_mm_loadu_si128(
                    reinterpret_cast<const __m128i*>(ip + (8)))),
                16)),
            vop8);
        // skip unnecessary prefetch of (&ip_next_T0[8])
      }
      if (!normalize_by_lengths || length == 0) {
        _mm256_storeu_ps(&op[0], vop0);
        _mm256_storeu_ps(&op[8], vop8);
      } else {
        __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        _mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));
        _mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));
      }
    }
  } else {
    // generic code
    // Scratch slot for one bf16 scalar so the vector conversion sequence can
    // be reused in the scalar tail; only element 0 is ever written/read.
    // NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)
    alignas(64) at::BFloat16 vtmp1[8] = {0};
    for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      // Zero the output row: vector part, then scalar remainder.
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      int64_t end_offset = offsets[rangeIndex + 1];
      int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        __m256 vwgt = _mm256_set1_ps(wgt);
        const at::BFloat16* ip = &input[idx * fused_block_size];
        const int next_T0 = (dataInd < index_size - prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            ? (dataInd + prefdist_T0)
            // NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)
            : dataInd;
        const int idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        // Accumulate in-place in `out` 8 floats at a time.
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_castsi256_ps(_mm256_slli_epi32(
                      _mm256_cvtepu16_epi32(_mm_loadu_si128(
                          reinterpret_cast<const __m128i*>(&ip[j]))),
                      16)),
                  _mm256_loadu_ps(&op[j])));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        // Scalar tail: convert one bf16 via the vector path, use lane 0.
        for (; j < block_size; j++) {
          vtmp1[0] = ip[j];
          __m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(
              _mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),
              16));
          op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);
        }
      }
      if (normalize_by_lengths && length) {
        float len_inv = 1.0f / length;
        __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  // Success only if every index was consumed exactly once.
  return dataInd == index_size;
}
2630bool EmbeddingLookupIdx_int32_t_bfloat16_float_false__avx2_fma(2631const int64_t block_size,2632const int64_t output_size,2633const int64_t index_size,2634const int64_t data_size,2635const at::BFloat16* input,2636const int* indices,2637const int* offsets,2638const float* weights,2639const float* scale_bias,2640bool normalize_by_lengths,2641float* out) {2642return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<false>(2643block_size,2644output_size,2645index_size,2646data_size,2647input,2648indices,2649offsets,2650weights,2651scale_bias,2652normalize_by_lengths,2653out);2654}
2655bool EmbeddingLookupIdx_int32_t_bfloat16_float_true__avx2_fma(2656const int64_t block_size,2657const int64_t output_size,2658const int64_t index_size,2659const int64_t data_size,2660const at::BFloat16* input,2661const int* indices,2662const int* offsets,2663const float* weights,2664const float* scale_bias,2665bool normalize_by_lengths,2666float* out) {2667return EmbeddingLookupIdx_int32_t_bfloat16_float__avx2_fma<true>(2668block_size,2669output_size,2670index_size,2671data_size,2672input,2673indices,2674offsets,2675weights,2676scale_bias,2677normalize_by_lengths,2678out);2679}
2680
2681template <bool IS_WEIGHT_POSITIONAL>2682static bool EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma(2683const int64_t block_size,2684const int64_t output_size,2685const int64_t index_size,2686const int64_t data_size,2687const at::BFloat16* input,2688const int64_t* indices,2689const int64_t* offsets,2690const float* weights,2691const float* scale_bias,2692bool normalize_by_lengths,2693float* out) {2694const int64_t prefdist_T0 = 16;2695// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)2696const int64_t fused_block_size = block_size + 0;2697int64_t dataInd = 0;2698if (block_size == 128) {2699// unrolling 16 times2700for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {2701float* op = &out[rangeIndex * block_size];2702__m256 vop0 = _mm256_setzero_ps();2703__m256 vop8 = _mm256_setzero_ps();2704__m256 vop16 = _mm256_setzero_ps();2705__m256 vop24 = _mm256_setzero_ps();2706__m256 vop32 = _mm256_setzero_ps();2707__m256 vop40 = _mm256_setzero_ps();2708__m256 vop48 = _mm256_setzero_ps();2709__m256 vop56 = _mm256_setzero_ps();2710__m256 vop64 = _mm256_setzero_ps();2711__m256 vop72 = _mm256_setzero_ps();2712__m256 vop80 = _mm256_setzero_ps();2713__m256 vop88 = _mm256_setzero_ps();2714__m256 vop96 = _mm256_setzero_ps();2715__m256 vop104 = _mm256_setzero_ps();2716__m256 vop112 = _mm256_setzero_ps();2717__m256 vop120 = _mm256_setzero_ps();2718if (dataInd != offsets[rangeIndex] - offsets[0]) {2719return false;2720}2721int64_t end_offset = offsets[rangeIndex + 1];2722int64_t length = end_offset - offsets[rangeIndex];2723for (int64_t start = dataInd; dataInd < end_offset - offsets[0];2724++dataInd) {2725const int64_t idx = indices[dataInd];2726if (idx < 0 || idx >= data_size) {2727return false;2728}2729float wgt = 1.f;2730if (weights) {2731wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd];2732}2733__m256 vwgt = _mm256_set1_ps(wgt);2734const at::BFloat16* ip = &input[idx * fused_block_size];2735const int64_t next_T0 = (dataInd < index_size - prefdist_T0)2736// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)2737? (dataInd + prefdist_T0)2738// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)2739: dataInd;2740const int64_t idx_pref_T0 = indices[next_T0];2741if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {2742return false;2743}2744const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];2745vop0 = _mm256_fmadd_ps(2746vwgt,2747_mm256_castsi256_ps(_mm256_slli_epi32(2748_mm256_cvtepu16_epi32(_mm_loadu_si128(2749reinterpret_cast<const __m128i*>(ip + (0)))),275016)),2751vop0);2752_mm_prefetch(2753reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);2754vop8 = _mm256_fmadd_ps(2755vwgt,2756_mm256_castsi256_ps(_mm256_slli_epi32(2757_mm256_cvtepu16_epi32(_mm_loadu_si128(2758reinterpret_cast<const __m128i*>(ip + (8)))),275916)),2760vop8);2761// skip unnecessary prefetch of (&ip_next_T0[8])2762vop16 = _mm256_fmadd_ps(2763vwgt,2764_mm256_castsi256_ps(_mm256_slli_epi32(2765_mm256_cvtepu16_epi32(_mm_loadu_si128(2766reinterpret_cast<const __m128i*>(ip + (16)))),276716)),2768vop16);2769// skip unnecessary prefetch of (&ip_next_T0[16])2770vop24 = _mm256_fmadd_ps(2771vwgt,2772_mm256_castsi256_ps(_mm256_slli_epi32(2773_mm256_cvtepu16_epi32(_mm_loadu_si128(2774reinterpret_cast<const __m128i*>(ip + (24)))),277516)),2776vop24);2777// skip unnecessary prefetch of (&ip_next_T0[24])2778vop32 = _mm256_fmadd_ps(2779vwgt,2780_mm256_castsi256_ps(_mm256_slli_epi32(2781_mm256_cvtepu16_epi32(_mm_loadu_si128(2782reinterpret_cast<const __m128i*>(ip + (32)))),278316)),2784vop32);2785_mm_prefetch(2786reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);2787vop40 = 
_mm256_fmadd_ps(2788vwgt,2789_mm256_castsi256_ps(_mm256_slli_epi32(2790_mm256_cvtepu16_epi32(_mm_loadu_si128(2791reinterpret_cast<const __m128i*>(ip + (40)))),279216)),2793vop40);2794// skip unnecessary prefetch of (&ip_next_T0[40])2795vop48 = _mm256_fmadd_ps(2796vwgt,2797_mm256_castsi256_ps(_mm256_slli_epi32(2798_mm256_cvtepu16_epi32(_mm_loadu_si128(2799reinterpret_cast<const __m128i*>(ip + (48)))),280016)),2801vop48);2802// skip unnecessary prefetch of (&ip_next_T0[48])2803vop56 = _mm256_fmadd_ps(2804vwgt,2805_mm256_castsi256_ps(_mm256_slli_epi32(2806_mm256_cvtepu16_epi32(_mm_loadu_si128(2807reinterpret_cast<const __m128i*>(ip + (56)))),280816)),2809vop56);2810// skip unnecessary prefetch of (&ip_next_T0[56])2811vop64 = _mm256_fmadd_ps(2812vwgt,2813_mm256_castsi256_ps(_mm256_slli_epi32(2814_mm256_cvtepu16_epi32(_mm_loadu_si128(2815reinterpret_cast<const __m128i*>(ip + (64)))),281616)),2817vop64);2818_mm_prefetch(2819reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);2820vop72 = _mm256_fmadd_ps(2821vwgt,2822_mm256_castsi256_ps(_mm256_slli_epi32(2823_mm256_cvtepu16_epi32(_mm_loadu_si128(2824reinterpret_cast<const __m128i*>(ip + (72)))),282516)),2826vop72);2827// skip unnecessary prefetch of (&ip_next_T0[72])2828vop80 = _mm256_fmadd_ps(2829vwgt,2830_mm256_castsi256_ps(_mm256_slli_epi32(2831_mm256_cvtepu16_epi32(_mm_loadu_si128(2832reinterpret_cast<const __m128i*>(ip + (80)))),283316)),2834vop80);2835// skip unnecessary prefetch of (&ip_next_T0[80])2836vop88 = _mm256_fmadd_ps(2837vwgt,2838_mm256_castsi256_ps(_mm256_slli_epi32(2839_mm256_cvtepu16_epi32(_mm_loadu_si128(2840reinterpret_cast<const __m128i*>(ip + (88)))),284116)),2842vop88);2843// skip unnecessary prefetch of (&ip_next_T0[88])2844vop96 = _mm256_fmadd_ps(2845vwgt,2846_mm256_castsi256_ps(_mm256_slli_epi32(2847_mm256_cvtepu16_epi32(_mm_loadu_si128(2848reinterpret_cast<const __m128i*>(ip + (96)))),284916)),2850vop96);2851_mm_prefetch(2852reinterpret_cast<const char*>(&ip_next_T0[96]), 
_MM_HINT_T0);2853vop104 = _mm256_fmadd_ps(2854vwgt,2855_mm256_castsi256_ps(_mm256_slli_epi32(2856_mm256_cvtepu16_epi32(_mm_loadu_si128(2857reinterpret_cast<const __m128i*>(ip + (104)))),285816)),2859vop104);2860// skip unnecessary prefetch of (&ip_next_T0[104])2861vop112 = _mm256_fmadd_ps(2862vwgt,2863_mm256_castsi256_ps(_mm256_slli_epi32(2864_mm256_cvtepu16_epi32(_mm_loadu_si128(2865reinterpret_cast<const __m128i*>(ip + (112)))),286616)),2867vop112);2868// skip unnecessary prefetch of (&ip_next_T0[112])2869vop120 = _mm256_fmadd_ps(2870vwgt,2871_mm256_castsi256_ps(_mm256_slli_epi32(2872_mm256_cvtepu16_epi32(_mm_loadu_si128(2873reinterpret_cast<const __m128i*>(ip + (120)))),287416)),2875vop120);2876// skip unnecessary prefetch of (&ip_next_T0[120])2877}2878if (!normalize_by_lengths || length == 0) {2879_mm256_storeu_ps(&op[0], vop0);2880_mm256_storeu_ps(&op[8], vop8);2881_mm256_storeu_ps(&op[16], vop16);2882_mm256_storeu_ps(&op[24], vop24);2883_mm256_storeu_ps(&op[32], vop32);2884_mm256_storeu_ps(&op[40], vop40);2885_mm256_storeu_ps(&op[48], vop48);2886_mm256_storeu_ps(&op[56], vop56);2887_mm256_storeu_ps(&op[64], vop64);2888_mm256_storeu_ps(&op[72], vop72);2889_mm256_storeu_ps(&op[80], vop80);2890_mm256_storeu_ps(&op[88], vop88);2891_mm256_storeu_ps(&op[96], vop96);2892_mm256_storeu_ps(&op[104], vop104);2893_mm256_storeu_ps(&op[112], vop112);2894_mm256_storeu_ps(&op[120], vop120);2895} else {2896__m256 vlen_inv = _mm256_set1_ps(1.0f / length);2897_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));2898_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));2899_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));2900_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));2901_mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));2902_mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));2903_mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));2904_mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));2905_mm256_storeu_ps(&op[64], 
_mm256_mul_ps(vop64, vlen_inv));2906_mm256_storeu_ps(&op[72], _mm256_mul_ps(vop72, vlen_inv));2907_mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));2908_mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));2909_mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));2910_mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));2911_mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));2912_mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));2913}2914}2915} else if (block_size == 64) {2916// unrolling 8 times2917for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {2918float* op = &out[rangeIndex * block_size];2919__m256 vop0 = _mm256_setzero_ps();2920__m256 vop8 = _mm256_setzero_ps();2921__m256 vop16 = _mm256_setzero_ps();2922__m256 vop24 = _mm256_setzero_ps();2923__m256 vop32 = _mm256_setzero_ps();2924__m256 vop40 = _mm256_setzero_ps();2925__m256 vop48 = _mm256_setzero_ps();2926__m256 vop56 = _mm256_setzero_ps();2927if (dataInd != offsets[rangeIndex] - offsets[0]) {2928return false;2929}2930int64_t end_offset = offsets[rangeIndex + 1];2931int64_t length = end_offset - offsets[rangeIndex];2932for (int64_t start = dataInd; dataInd < end_offset - offsets[0];2933++dataInd) {2934const int64_t idx = indices[dataInd];2935if (idx < 0 || idx >= data_size) {2936return false;2937}2938float wgt = 1.f;2939if (weights) {2940wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];2941}2942__m256 vwgt = _mm256_set1_ps(wgt);2943const at::BFloat16* ip = &input[idx * fused_block_size];2944const int64_t next_T0 = (dataInd < index_size - prefdist_T0)2945// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)2946? 
(dataInd + prefdist_T0)2947// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)2948: dataInd;2949const int64_t idx_pref_T0 = indices[next_T0];2950if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {2951return false;2952}2953const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];2954vop0 = _mm256_fmadd_ps(2955vwgt,2956_mm256_castsi256_ps(_mm256_slli_epi32(2957_mm256_cvtepu16_epi32(_mm_loadu_si128(2958reinterpret_cast<const __m128i*>(ip + (0)))),295916)),2960vop0);2961_mm_prefetch(2962reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);2963vop8 = _mm256_fmadd_ps(2964vwgt,2965_mm256_castsi256_ps(_mm256_slli_epi32(2966_mm256_cvtepu16_epi32(_mm_loadu_si128(2967reinterpret_cast<const __m128i*>(ip + (8)))),296816)),2969vop8);2970// skip unnecessary prefetch of (&ip_next_T0[8])2971vop16 = _mm256_fmadd_ps(2972vwgt,2973_mm256_castsi256_ps(_mm256_slli_epi32(2974_mm256_cvtepu16_epi32(_mm_loadu_si128(2975reinterpret_cast<const __m128i*>(ip + (16)))),297616)),2977vop16);2978// skip unnecessary prefetch of (&ip_next_T0[16])2979vop24 = _mm256_fmadd_ps(2980vwgt,2981_mm256_castsi256_ps(_mm256_slli_epi32(2982_mm256_cvtepu16_epi32(_mm_loadu_si128(2983reinterpret_cast<const __m128i*>(ip + (24)))),298416)),2985vop24);2986// skip unnecessary prefetch of (&ip_next_T0[24])2987vop32 = _mm256_fmadd_ps(2988vwgt,2989_mm256_castsi256_ps(_mm256_slli_epi32(2990_mm256_cvtepu16_epi32(_mm_loadu_si128(2991reinterpret_cast<const __m128i*>(ip + (32)))),299216)),2993vop32);2994_mm_prefetch(2995reinterpret_cast<const char*>(&ip_next_T0[32]), _MM_HINT_T0);2996vop40 = _mm256_fmadd_ps(2997vwgt,2998_mm256_castsi256_ps(_mm256_slli_epi32(2999_mm256_cvtepu16_epi32(_mm_loadu_si128(3000reinterpret_cast<const __m128i*>(ip + (40)))),300116)),3002vop40);3003// skip unnecessary prefetch of (&ip_next_T0[40])3004vop48 = _mm256_fmadd_ps(3005vwgt,3006_mm256_castsi256_ps(_mm256_slli_epi32(3007_mm256_cvtepu16_epi32(_mm_loadu_si128(3008reinterpret_cast<const 
__m128i*>(ip + (48)))),300916)),3010vop48);3011// skip unnecessary prefetch of (&ip_next_T0[48])3012vop56 = _mm256_fmadd_ps(3013vwgt,3014_mm256_castsi256_ps(_mm256_slli_epi32(3015_mm256_cvtepu16_epi32(_mm_loadu_si128(3016reinterpret_cast<const __m128i*>(ip + (56)))),301716)),3018vop56);3019// skip unnecessary prefetch of (&ip_next_T0[56])3020}3021if (!normalize_by_lengths || length == 0) {3022_mm256_storeu_ps(&op[0], vop0);3023_mm256_storeu_ps(&op[8], vop8);3024_mm256_storeu_ps(&op[16], vop16);3025_mm256_storeu_ps(&op[24], vop24);3026_mm256_storeu_ps(&op[32], vop32);3027_mm256_storeu_ps(&op[40], vop40);3028_mm256_storeu_ps(&op[48], vop48);3029_mm256_storeu_ps(&op[56], vop56);3030} else {3031__m256 vlen_inv = _mm256_set1_ps(1.0f / length);3032_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));3033_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));3034_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));3035_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));3036_mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));3037_mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));3038_mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));3039_mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));3040}3041}3042} else if (block_size == 32) {3043// unrolling 4 times3044for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3045float* op = &out[rangeIndex * block_size];3046__m256 vop0 = _mm256_setzero_ps();3047__m256 vop8 = _mm256_setzero_ps();3048__m256 vop16 = _mm256_setzero_ps();3049__m256 vop24 = _mm256_setzero_ps();3050if (dataInd != offsets[rangeIndex] - offsets[0]) {3051return false;3052}3053int64_t end_offset = offsets[rangeIndex + 1];3054int64_t length = end_offset - offsets[rangeIndex];3055for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3056++dataInd) {3057const int64_t idx = indices[dataInd];3058if (idx < 0 || idx >= data_size) {3059return false;3060}3061float wgt = 1.f;3062if (weights) {3063wgt = 
weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];3064}3065__m256 vwgt = _mm256_set1_ps(wgt);3066const at::BFloat16* ip = &input[idx * fused_block_size];3067const int64_t next_T0 = (dataInd < index_size - prefdist_T0)3068// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3069? (dataInd + prefdist_T0)3070// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3071: dataInd;3072const int64_t idx_pref_T0 = indices[next_T0];3073if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3074return false;3075}3076const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3077vop0 = _mm256_fmadd_ps(3078vwgt,3079_mm256_castsi256_ps(_mm256_slli_epi32(3080_mm256_cvtepu16_epi32(_mm_loadu_si128(3081reinterpret_cast<const __m128i*>(ip + (0)))),308216)),3083vop0);3084_mm_prefetch(3085reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);3086vop8 = _mm256_fmadd_ps(3087vwgt,3088_mm256_castsi256_ps(_mm256_slli_epi32(3089_mm256_cvtepu16_epi32(_mm_loadu_si128(3090reinterpret_cast<const __m128i*>(ip + (8)))),309116)),3092vop8);3093// skip unnecessary prefetch of (&ip_next_T0[8])3094vop16 = _mm256_fmadd_ps(3095vwgt,3096_mm256_castsi256_ps(_mm256_slli_epi32(3097_mm256_cvtepu16_epi32(_mm_loadu_si128(3098reinterpret_cast<const __m128i*>(ip + (16)))),309916)),3100vop16);3101// skip unnecessary prefetch of (&ip_next_T0[16])3102vop24 = _mm256_fmadd_ps(3103vwgt,3104_mm256_castsi256_ps(_mm256_slli_epi32(3105_mm256_cvtepu16_epi32(_mm_loadu_si128(3106reinterpret_cast<const __m128i*>(ip + (24)))),310716)),3108vop24);3109// skip unnecessary prefetch of (&ip_next_T0[24])3110}3111if (!normalize_by_lengths || length == 0) {3112_mm256_storeu_ps(&op[0], vop0);3113_mm256_storeu_ps(&op[8], vop8);3114_mm256_storeu_ps(&op[16], vop16);3115_mm256_storeu_ps(&op[24], vop24);3116} else {3117__m256 vlen_inv = _mm256_set1_ps(1.0f / length);3118_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, 
vlen_inv));3119_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));3120_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));3121_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));3122}3123}3124} else if (block_size == 16) {3125// unrolling 2 times3126for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3127float* op = &out[rangeIndex * block_size];3128__m256 vop0 = _mm256_setzero_ps();3129__m256 vop8 = _mm256_setzero_ps();3130if (dataInd != offsets[rangeIndex] - offsets[0]) {3131return false;3132}3133int64_t end_offset = offsets[rangeIndex + 1];3134int64_t length = end_offset - offsets[rangeIndex];3135for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3136++dataInd) {3137const int64_t idx = indices[dataInd];3138if (idx < 0 || idx >= data_size) {3139return false;3140}3141float wgt = 1.f;3142if (weights) {3143wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];3144}3145__m256 vwgt = _mm256_set1_ps(wgt);3146const at::BFloat16* ip = &input[idx * fused_block_size];3147const int64_t next_T0 = (dataInd < index_size - prefdist_T0)3148// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3149? 
(dataInd + prefdist_T0)3150// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3151: dataInd;3152const int64_t idx_pref_T0 = indices[next_T0];3153if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3154return false;3155}3156const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3157vop0 = _mm256_fmadd_ps(3158vwgt,3159_mm256_castsi256_ps(_mm256_slli_epi32(3160_mm256_cvtepu16_epi32(_mm_loadu_si128(3161reinterpret_cast<const __m128i*>(ip + (0)))),316216)),3163vop0);3164_mm_prefetch(3165reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);3166vop8 = _mm256_fmadd_ps(3167vwgt,3168_mm256_castsi256_ps(_mm256_slli_epi32(3169_mm256_cvtepu16_epi32(_mm_loadu_si128(3170reinterpret_cast<const __m128i*>(ip + (8)))),317116)),3172vop8);3173// skip unnecessary prefetch of (&ip_next_T0[8])3174}3175if (!normalize_by_lengths || length == 0) {3176_mm256_storeu_ps(&op[0], vop0);3177_mm256_storeu_ps(&op[8], vop8);3178} else {3179__m256 vlen_inv = _mm256_set1_ps(1.0f / length);3180_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));3181_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));3182}3183}3184} else {3185// generic code3186// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)3187alignas(64) at::BFloat16 vtmp1[8] = {0};3188for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3189float* op = &out[rangeIndex * block_size];3190int64_t j = 0;3191for (; j + 8 <= block_size; j += 8) {3192_mm256_storeu_ps(op + j, _mm256_setzero_ps());3193}3194for (; j < block_size; j++) {3195op[j] = 0.0f;3196}3197if (dataInd != offsets[rangeIndex] - offsets[0]) {3198return false;3199}3200int64_t end_offset = offsets[rangeIndex + 1];3201int64_t length = end_offset - offsets[rangeIndex];3202for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3203++dataInd) {3204const int64_t idx = indices[dataInd];3205if (idx < 0 || idx >= data_size) {3206return 
false;3207}3208float wgt = 1.f;3209if (weights) {3210wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];3211}3212__m256 vwgt = _mm256_set1_ps(wgt);3213const at::BFloat16* ip = &input[idx * fused_block_size];3214const int64_t next_T0 = (dataInd < index_size - prefdist_T0)3215// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3216? (dataInd + prefdist_T0)3217// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3218: dataInd;3219const int64_t idx_pref_T0 = indices[next_T0];3220if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3221return false;3222}3223const at::BFloat16* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3224j = 0;3225for (; j + 8 <= block_size; j += 8) {3226_mm256_storeu_ps(3227&op[j],3228_mm256_fmadd_ps(3229vwgt,3230_mm256_castsi256_ps(_mm256_slli_epi32(3231_mm256_cvtepu16_epi32(_mm_loadu_si128(3232reinterpret_cast<const __m128i*>(&ip[j]))),323316)),3234_mm256_loadu_ps(&op[j])));3235_mm_prefetch(3236reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);3237}3238for (; j < block_size; j++) {3239vtmp1[0] = ip[j];3240__m256 vtmp2 = _mm256_castsi256_ps(_mm256_slli_epi32(3241_mm256_cvtepu16_epi32(*(reinterpret_cast<const __m128i*>(vtmp1))),324216));3243op[j] = std::fma(wgt, ((float*)(&vtmp2))[0], op[j]);3244}3245}3246if (normalize_by_lengths && length) {3247float len_inv = 1.0f / length;3248__m256 vlen_inv = _mm256_set1_ps(len_inv);3249j = 0;3250for (; j + 8 <= block_size; j += 8) {3251_mm256_storeu_ps(3252&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));3253}3254for (; j < block_size; j++) {3255op[j] = len_inv * op[j];3256}3257}3258}3259}3260return dataInd == index_size;3261}
3262bool EmbeddingLookupIdx_int64_t_bfloat16_float_false__avx2_fma(3263const int64_t block_size,3264const int64_t output_size,3265const int64_t index_size,3266const int64_t data_size,3267const at::BFloat16* input,3268const int64_t* indices,3269const int64_t* offsets,3270const float* weights,3271const float* scale_bias,3272bool normalize_by_lengths,3273float* out) {3274return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<false>(3275block_size,3276output_size,3277index_size,3278data_size,3279input,3280indices,3281offsets,3282weights,3283scale_bias,3284normalize_by_lengths,3285out);3286}
3287bool EmbeddingLookupIdx_int64_t_bfloat16_float_true__avx2_fma(3288const int64_t block_size,3289const int64_t output_size,3290const int64_t index_size,3291const int64_t data_size,3292const at::BFloat16* input,3293const int64_t* indices,3294const int64_t* offsets,3295const float* weights,3296const float* scale_bias,3297bool normalize_by_lengths,3298float* out) {3299return EmbeddingLookupIdx_int64_t_bfloat16_float__avx2_fma<true>(3300block_size,3301output_size,3302index_size,3303data_size,3304input,3305indices,3306offsets,3307weights,3308scale_bias,3309normalize_by_lengths,3310out);3311}
3312
3313template <bool IS_WEIGHT_POSITIONAL>3314static bool EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma(3315const int64_t block_size,3316const int64_t output_size,3317const int64_t index_size,3318const int64_t data_size,3319const uint8_t* input,3320const int* indices,3321const int* offsets,3322const float* weights,3323const float* scale_bias,3324bool normalize_by_lengths,3325float* out) {3326const int prefdist_T0 = 16;3327// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3328const int fused_block_size = block_size + 0;3329int64_t dataInd = 0;3330if (block_size == 128) {3331// unrolling 16 times3332for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3333float* op = &out[rangeIndex * block_size];3334__m256 vop0 = _mm256_setzero_ps();3335__m256 vop8 = _mm256_setzero_ps();3336__m256 vop16 = _mm256_setzero_ps();3337__m256 vop24 = _mm256_setzero_ps();3338__m256 vop32 = _mm256_setzero_ps();3339__m256 vop40 = _mm256_setzero_ps();3340__m256 vop48 = _mm256_setzero_ps();3341__m256 vop56 = _mm256_setzero_ps();3342__m256 vop64 = _mm256_setzero_ps();3343__m256 vop72 = _mm256_setzero_ps();3344__m256 vop80 = _mm256_setzero_ps();3345__m256 vop88 = _mm256_setzero_ps();3346__m256 vop96 = _mm256_setzero_ps();3347__m256 vop104 = _mm256_setzero_ps();3348__m256 vop112 = _mm256_setzero_ps();3349__m256 vop120 = _mm256_setzero_ps();3350if (dataInd != offsets[rangeIndex] - offsets[0]) {3351return false;3352}3353int64_t end_offset = offsets[rangeIndex + 1];3354int64_t length = end_offset - offsets[rangeIndex];3355for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3356++dataInd) {3357const int idx = indices[dataInd];3358if (idx < 0 || idx >= data_size) {3359return false;3360}3361float wgt = 1.f;3362// NOLINTNEXTLINE(cppcoreguidelines-init-variables)3363float bio;3364if (weights) {3365wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd];3366}3367bio = wgt * scale_bias[2 * idx + 1];3368wgt = wgt * scale_bias[2 * idx];3369__m256 vbio = _mm256_set1_ps(bio);3370__m256 vwgt = _mm256_set1_ps(wgt);3371const uint8_t* ip = &input[idx * fused_block_size];3372const int next_T0 = (dataInd < index_size - prefdist_T0)3373// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3374? (dataInd + prefdist_T0)3375// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3376: dataInd;3377const int idx_pref_T0 = indices[next_T0];3378if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3379return false;3380}3381const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3382vop0 = _mm256_fmadd_ps(3383vwgt,3384_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3385_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),3386_mm256_add_ps(vop0, vbio));3387_mm_prefetch(3388reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);3389vop8 = _mm256_fmadd_ps(3390vwgt,3391_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3392_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),3393_mm256_add_ps(vop8, vbio));3394// skip unnecessary prefetch of (&ip_next_T0[8])3395vop16 = _mm256_fmadd_ps(3396vwgt,3397_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3398_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),3399_mm256_add_ps(vop16, vbio));3400// skip unnecessary prefetch of (&ip_next_T0[16])3401vop24 = _mm256_fmadd_ps(3402vwgt,3403_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3404_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),3405_mm256_add_ps(vop24, vbio));3406// skip unnecessary prefetch of (&ip_next_T0[24])3407vop32 = _mm256_fmadd_ps(3408vwgt,3409_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3410_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),3411_mm256_add_ps(vop32, vbio));3412// skip unnecessary prefetch of (&ip_next_T0[32])3413vop40 = 
_mm256_fmadd_ps(3414vwgt,3415_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3416_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),3417_mm256_add_ps(vop40, vbio));3418// skip unnecessary prefetch of (&ip_next_T0[40])3419vop48 = _mm256_fmadd_ps(3420vwgt,3421_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3422_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),3423_mm256_add_ps(vop48, vbio));3424// skip unnecessary prefetch of (&ip_next_T0[48])3425vop56 = _mm256_fmadd_ps(3426vwgt,3427_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3428_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),3429_mm256_add_ps(vop56, vbio));3430// skip unnecessary prefetch of (&ip_next_T0[56])3431vop64 = _mm256_fmadd_ps(3432vwgt,3433_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3434_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (64))))),3435_mm256_add_ps(vop64, vbio));3436_mm_prefetch(3437reinterpret_cast<const char*>(&ip_next_T0[64]), _MM_HINT_T0);3438vop72 = _mm256_fmadd_ps(3439vwgt,3440_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3441_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (72))))),3442_mm256_add_ps(vop72, vbio));3443// skip unnecessary prefetch of (&ip_next_T0[72])3444vop80 = _mm256_fmadd_ps(3445vwgt,3446_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3447_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (80))))),3448_mm256_add_ps(vop80, vbio));3449// skip unnecessary prefetch of (&ip_next_T0[80])3450vop88 = _mm256_fmadd_ps(3451vwgt,3452_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3453_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (88))))),3454_mm256_add_ps(vop88, vbio));3455// skip unnecessary prefetch of (&ip_next_T0[88])3456vop96 = _mm256_fmadd_ps(3457vwgt,3458_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3459_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (96))))),3460_mm256_add_ps(vop96, vbio));3461// skip unnecessary prefetch of (&ip_next_T0[96])3462vop104 = 
_mm256_fmadd_ps(3463vwgt,3464_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3465_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (104))))),3466_mm256_add_ps(vop104, vbio));3467// skip unnecessary prefetch of (&ip_next_T0[104])3468vop112 = _mm256_fmadd_ps(3469vwgt,3470_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3471_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (112))))),3472_mm256_add_ps(vop112, vbio));3473// skip unnecessary prefetch of (&ip_next_T0[112])3474vop120 = _mm256_fmadd_ps(3475vwgt,3476_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3477_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (120))))),3478_mm256_add_ps(vop120, vbio));3479// skip unnecessary prefetch of (&ip_next_T0[120])3480}3481if (!normalize_by_lengths || length == 0) {3482_mm256_storeu_ps(&op[0], vop0);3483_mm256_storeu_ps(&op[8], vop8);3484_mm256_storeu_ps(&op[16], vop16);3485_mm256_storeu_ps(&op[24], vop24);3486_mm256_storeu_ps(&op[32], vop32);3487_mm256_storeu_ps(&op[40], vop40);3488_mm256_storeu_ps(&op[48], vop48);3489_mm256_storeu_ps(&op[56], vop56);3490_mm256_storeu_ps(&op[64], vop64);3491_mm256_storeu_ps(&op[72], vop72);3492_mm256_storeu_ps(&op[80], vop80);3493_mm256_storeu_ps(&op[88], vop88);3494_mm256_storeu_ps(&op[96], vop96);3495_mm256_storeu_ps(&op[104], vop104);3496_mm256_storeu_ps(&op[112], vop112);3497_mm256_storeu_ps(&op[120], vop120);3498} else {3499__m256 vlen_inv = _mm256_set1_ps(1.0f / length);3500_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));3501_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));3502_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));3503_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));3504_mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));3505_mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));3506_mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));3507_mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));3508_mm256_storeu_ps(&op[64], _mm256_mul_ps(vop64, vlen_inv));3509_mm256_storeu_ps(&op[72], 
_mm256_mul_ps(vop72, vlen_inv));3510_mm256_storeu_ps(&op[80], _mm256_mul_ps(vop80, vlen_inv));3511_mm256_storeu_ps(&op[88], _mm256_mul_ps(vop88, vlen_inv));3512_mm256_storeu_ps(&op[96], _mm256_mul_ps(vop96, vlen_inv));3513_mm256_storeu_ps(&op[104], _mm256_mul_ps(vop104, vlen_inv));3514_mm256_storeu_ps(&op[112], _mm256_mul_ps(vop112, vlen_inv));3515_mm256_storeu_ps(&op[120], _mm256_mul_ps(vop120, vlen_inv));3516}3517}3518} else if (block_size == 64) {3519// unrolling 8 times3520for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3521float* op = &out[rangeIndex * block_size];3522__m256 vop0 = _mm256_setzero_ps();3523__m256 vop8 = _mm256_setzero_ps();3524__m256 vop16 = _mm256_setzero_ps();3525__m256 vop24 = _mm256_setzero_ps();3526__m256 vop32 = _mm256_setzero_ps();3527__m256 vop40 = _mm256_setzero_ps();3528__m256 vop48 = _mm256_setzero_ps();3529__m256 vop56 = _mm256_setzero_ps();3530if (dataInd != offsets[rangeIndex] - offsets[0]) {3531return false;3532}3533int64_t end_offset = offsets[rangeIndex + 1];3534int64_t length = end_offset - offsets[rangeIndex];3535for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3536++dataInd) {3537const int idx = indices[dataInd];3538if (idx < 0 || idx >= data_size) {3539return false;3540}3541float wgt = 1.f;3542// NOLINTNEXTLINE(cppcoreguidelines-init-variables)3543float bio;3544if (weights) {3545wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];3546}3547bio = wgt * scale_bias[2 * idx + 1];3548wgt = wgt * scale_bias[2 * idx];3549__m256 vbio = _mm256_set1_ps(bio);3550__m256 vwgt = _mm256_set1_ps(wgt);3551const uint8_t* ip = &input[idx * fused_block_size];3552const int next_T0 = (dataInd < index_size - prefdist_T0)3553// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3554? 
(dataInd + prefdist_T0)3555// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3556: dataInd;3557const int idx_pref_T0 = indices[next_T0];3558if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3559return false;3560}3561const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3562vop0 = _mm256_fmadd_ps(3563vwgt,3564_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3565_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),3566_mm256_add_ps(vop0, vbio));3567_mm_prefetch(3568reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);3569vop8 = _mm256_fmadd_ps(3570vwgt,3571_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3572_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),3573_mm256_add_ps(vop8, vbio));3574// skip unnecessary prefetch of (&ip_next_T0[8])3575vop16 = _mm256_fmadd_ps(3576vwgt,3577_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3578_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),3579_mm256_add_ps(vop16, vbio));3580// skip unnecessary prefetch of (&ip_next_T0[16])3581vop24 = _mm256_fmadd_ps(3582vwgt,3583_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3584_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),3585_mm256_add_ps(vop24, vbio));3586// skip unnecessary prefetch of (&ip_next_T0[24])3587vop32 = _mm256_fmadd_ps(3588vwgt,3589_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3590_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (32))))),3591_mm256_add_ps(vop32, vbio));3592// skip unnecessary prefetch of (&ip_next_T0[32])3593vop40 = _mm256_fmadd_ps(3594vwgt,3595_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3596_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (40))))),3597_mm256_add_ps(vop40, vbio));3598// skip unnecessary prefetch of (&ip_next_T0[40])3599vop48 = _mm256_fmadd_ps(3600vwgt,3601_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3602_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (48))))),3603_mm256_add_ps(vop48, vbio));3604// skip unnecessary prefetch of (&ip_next_T0[48])3605vop56 = 
_mm256_fmadd_ps(3606vwgt,3607_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3608_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (56))))),3609_mm256_add_ps(vop56, vbio));3610// skip unnecessary prefetch of (&ip_next_T0[56])3611}3612if (!normalize_by_lengths || length == 0) {3613_mm256_storeu_ps(&op[0], vop0);3614_mm256_storeu_ps(&op[8], vop8);3615_mm256_storeu_ps(&op[16], vop16);3616_mm256_storeu_ps(&op[24], vop24);3617_mm256_storeu_ps(&op[32], vop32);3618_mm256_storeu_ps(&op[40], vop40);3619_mm256_storeu_ps(&op[48], vop48);3620_mm256_storeu_ps(&op[56], vop56);3621} else {3622__m256 vlen_inv = _mm256_set1_ps(1.0f / length);3623_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));3624_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));3625_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));3626_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));3627_mm256_storeu_ps(&op[32], _mm256_mul_ps(vop32, vlen_inv));3628_mm256_storeu_ps(&op[40], _mm256_mul_ps(vop40, vlen_inv));3629_mm256_storeu_ps(&op[48], _mm256_mul_ps(vop48, vlen_inv));3630_mm256_storeu_ps(&op[56], _mm256_mul_ps(vop56, vlen_inv));3631}3632}3633} else if (block_size == 32) {3634// unrolling 4 times3635for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3636float* op = &out[rangeIndex * block_size];3637__m256 vop0 = _mm256_setzero_ps();3638__m256 vop8 = _mm256_setzero_ps();3639__m256 vop16 = _mm256_setzero_ps();3640__m256 vop24 = _mm256_setzero_ps();3641if (dataInd != offsets[rangeIndex] - offsets[0]) {3642return false;3643}3644int64_t end_offset = offsets[rangeIndex + 1];3645int64_t length = end_offset - offsets[rangeIndex];3646for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3647++dataInd) {3648const int idx = indices[dataInd];3649if (idx < 0 || idx >= data_size) {3650return false;3651}3652float wgt = 1.f;3653// NOLINTNEXTLINE(cppcoreguidelines-init-variables)3654float bio;3655if (weights) {3656wgt = weights[IS_WEIGHT_POSITIONAL ? 
(dataInd - start) : dataInd];3657}3658bio = wgt * scale_bias[2 * idx + 1];3659wgt = wgt * scale_bias[2 * idx];3660__m256 vbio = _mm256_set1_ps(bio);3661__m256 vwgt = _mm256_set1_ps(wgt);3662const uint8_t* ip = &input[idx * fused_block_size];3663const int next_T0 = (dataInd < index_size - prefdist_T0)3664// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3665? (dataInd + prefdist_T0)3666// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3667: dataInd;3668const int idx_pref_T0 = indices[next_T0];3669if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3670return false;3671}3672const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3673vop0 = _mm256_fmadd_ps(3674vwgt,3675_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3676_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),3677_mm256_add_ps(vop0, vbio));3678_mm_prefetch(3679reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);3680vop8 = _mm256_fmadd_ps(3681vwgt,3682_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3683_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),3684_mm256_add_ps(vop8, vbio));3685// skip unnecessary prefetch of (&ip_next_T0[8])3686vop16 = _mm256_fmadd_ps(3687vwgt,3688_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3689_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (16))))),3690_mm256_add_ps(vop16, vbio));3691// skip unnecessary prefetch of (&ip_next_T0[16])3692vop24 = _mm256_fmadd_ps(3693vwgt,3694_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3695_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (24))))),3696_mm256_add_ps(vop24, vbio));3697// skip unnecessary prefetch of (&ip_next_T0[24])3698}3699if (!normalize_by_lengths || length == 0) {3700_mm256_storeu_ps(&op[0], vop0);3701_mm256_storeu_ps(&op[8], vop8);3702_mm256_storeu_ps(&op[16], vop16);3703_mm256_storeu_ps(&op[24], vop24);3704} else {3705__m256 vlen_inv = _mm256_set1_ps(1.0f / length);3706_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, 
vlen_inv));3707_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));3708_mm256_storeu_ps(&op[16], _mm256_mul_ps(vop16, vlen_inv));3709_mm256_storeu_ps(&op[24], _mm256_mul_ps(vop24, vlen_inv));3710}3711}3712} else if (block_size == 16) {3713// unrolling 2 times3714for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3715float* op = &out[rangeIndex * block_size];3716__m256 vop0 = _mm256_setzero_ps();3717__m256 vop8 = _mm256_setzero_ps();3718if (dataInd != offsets[rangeIndex] - offsets[0]) {3719return false;3720}3721int64_t end_offset = offsets[rangeIndex + 1];3722int64_t length = end_offset - offsets[rangeIndex];3723for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3724++dataInd) {3725const int idx = indices[dataInd];3726if (idx < 0 || idx >= data_size) {3727return false;3728}3729float wgt = 1.f;3730// NOLINTNEXTLINE(cppcoreguidelines-init-variables)3731float bio;3732if (weights) {3733wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];3734}3735bio = wgt * scale_bias[2 * idx + 1];3736wgt = wgt * scale_bias[2 * idx];3737__m256 vbio = _mm256_set1_ps(bio);3738__m256 vwgt = _mm256_set1_ps(wgt);3739const uint8_t* ip = &input[idx * fused_block_size];3740const int next_T0 = (dataInd < index_size - prefdist_T0)3741// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3742? 
(dataInd + prefdist_T0)3743// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3744: dataInd;3745const int idx_pref_T0 = indices[next_T0];3746if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3747return false;3748}3749const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3750vop0 = _mm256_fmadd_ps(3751vwgt,3752_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3753_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (0))))),3754_mm256_add_ps(vop0, vbio));3755_mm_prefetch(3756reinterpret_cast<const char*>(&ip_next_T0[0]), _MM_HINT_T0);3757vop8 = _mm256_fmadd_ps(3758vwgt,3759_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(3760_mm_loadl_epi64(reinterpret_cast<const __m128i*>(ip + (8))))),3761_mm256_add_ps(vop8, vbio));3762// skip unnecessary prefetch of (&ip_next_T0[8])3763}3764if (!normalize_by_lengths || length == 0) {3765_mm256_storeu_ps(&op[0], vop0);3766_mm256_storeu_ps(&op[8], vop8);3767} else {3768__m256 vlen_inv = _mm256_set1_ps(1.0f / length);3769_mm256_storeu_ps(&op[0], _mm256_mul_ps(vop0, vlen_inv));3770_mm256_storeu_ps(&op[8], _mm256_mul_ps(vop8, vlen_inv));3771}3772}3773} else {3774// generic code3775// NOLINTNEXTLINE(modernize-avoid-c-arrays,cppcoreguidelines-avoid-magic-numbers,cppcoreguidelines-avoid-c-arrays)3776for (int rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {3777float* op = &out[rangeIndex * block_size];3778int64_t j = 0;3779for (; j + 8 <= block_size; j += 8) {3780_mm256_storeu_ps(op + j, _mm256_setzero_ps());3781}3782for (; j < block_size; j++) {3783op[j] = 0.0f;3784}3785if (dataInd != offsets[rangeIndex] - offsets[0]) {3786return false;3787}3788int64_t end_offset = offsets[rangeIndex + 1];3789int64_t length = end_offset - offsets[rangeIndex];3790for (int64_t start = dataInd; dataInd < end_offset - offsets[0];3791++dataInd) {3792const int idx = indices[dataInd];3793if (idx < 0 || idx >= data_size) {3794return false;3795}3796float wgt = 1.f;3797// 
NOLINTNEXTLINE(cppcoreguidelines-init-variables)3798float bio;3799if (weights) {3800wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];3801}3802bio = wgt * scale_bias[2 * idx + 1];3803wgt = wgt * scale_bias[2 * idx];3804__m256 vbio = _mm256_set1_ps(bio);3805__m256 vwgt = _mm256_set1_ps(wgt);3806const uint8_t* ip = &input[idx * fused_block_size];3807const int next_T0 = (dataInd < index_size - prefdist_T0)3808// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3809? (dataInd + prefdist_T0)3810// NOLINTNEXTLINE(cppcoreguidelines-narrowing-conversions,bugprone-narrowing-conversions)3811: dataInd;3812const int idx_pref_T0 = indices[next_T0];3813if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {3814return false;3815}3816const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];3817j = 0;3818for (; j + 8 <= block_size; j += 8) {3819_mm256_storeu_ps(3820&op[j],3821_mm256_fmadd_ps(3822vwgt,3823_mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(3824reinterpret_cast<const __m128i*>(&ip[j])))),3825_mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));3826_mm_prefetch(3827reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);3828}3829for (; j < block_size; j++) {3830op[j] = std::fma(wgt, (float)ip[j], bio + op[j]);3831}3832}3833if (normalize_by_lengths && length) {3834float len_inv = 1.0f / length;3835__m256 vlen_inv = _mm256_set1_ps(len_inv);3836j = 0;3837for (; j + 8 <= block_size; j += 8) {3838_mm256_storeu_ps(3839&op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));3840}3841for (; j < block_size; j++) {3842op[j] = len_inv * op[j];3843}3844}3845}3846}3847return dataInd == index_size;3848}
3849bool EmbeddingLookupIdx_int32_t_uint8_t_float_false__avx2_fma(3850const int64_t block_size,3851const int64_t output_size,3852const int64_t index_size,3853const int64_t data_size,3854const uint8_t* input,3855const int* indices,3856const int* offsets,3857const float* weights,3858const float* scale_bias,3859bool normalize_by_lengths,3860float* out) {3861return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<false>(3862block_size,3863output_size,3864index_size,3865data_size,3866input,3867indices,3868offsets,3869weights,3870scale_bias,3871normalize_by_lengths,3872out);3873}
3874bool EmbeddingLookupIdx_int32_t_uint8_t_float_true__avx2_fma(3875const int64_t block_size,3876const int64_t output_size,3877const int64_t index_size,3878const int64_t data_size,3879const uint8_t* input,3880const int* indices,3881const int* offsets,3882const float* weights,3883const float* scale_bias,3884bool normalize_by_lengths,3885float* out) {3886return EmbeddingLookupIdx_int32_t_uint8_t_float__avx2_fma<true>(3887block_size,3888output_size,3889index_size,3890data_size,3891input,3892indices,3893offsets,3894weights,3895scale_bias,3896normalize_by_lengths,3897out);3898}
3899
// Accumulates uint8-quantized embedding rows into float outputs using
// AVX2 + FMA (64-bit indices/offsets).  Each row of `input` holds
// `block_size` quantized bytes; scale_bias[2*idx] and scale_bias[2*idx + 1]
// are that row's dequantization scale and bias.  For every output range r
// delimited by offsets[r]..offsets[r+1] this computes
//   out[r * block_size ..] = sum_i wgt_i * (scale_i * row_i + bias_i)
// where wgt_i comes from `weights` (indexed positionally within the range
// when IS_WEIGHT_POSITIONAL, by absolute index otherwise) or defaults to 1,
// and the sum is divided by the range length when normalize_by_lengths is
// set and the range is non-empty.  Returns false on inconsistent offsets or
// any index (including the prefetch lookahead index) outside [0, data_size).
template <bool IS_WEIGHT_POSITIONAL>
static bool EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma(
    const int64_t block_size,
    const int64_t output_size,
    const int64_t index_size,
    const int64_t data_size,
    const uint8_t* input,
    const int64_t* indices,
    const int64_t* offsets,
    const float* weights,
    const float* scale_bias,
    bool normalize_by_lengths,
    float* out) {
  // Rows this many iterations ahead in `indices` are prefetched into L1.
  const int64_t prefdist_T0 = 16;
  const int64_t fused_block_size = block_size + 0;
  int64_t dataInd = 0;
  if (block_size == 128 || block_size == 64 || block_size == 32 ||
      block_size == 16) {
    // Register-resident path for the specialized sizes: the whole output
    // block fits in at most 16 YMM accumulators of 8 floats each, so the
    // inner loop never touches `out` until the range is finished.
    const int num_lanes = static_cast<int>(block_size / 8);
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      __m256 vop[16];
      for (int lane = 0; lane < num_lanes; ++lane) {
        vop[lane] = _mm256_setzero_ps();
      }
      // The traversal cursor must line up with this range's start offset.
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        // Fold the per-row dequantization (scale, bias) into the weight so
        // each lane needs only a single fused multiply-add.
        const float bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        // Row needed prefdist_T0 iterations from now (clamped at the end).
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        for (int lane = 0; lane < num_lanes; ++lane) {
          const int off = lane * 8;
          // Widen 8 bytes to 8 floats, then vop += wgt * v + bio.
          vop[lane] = _mm256_fmadd_ps(
              vwgt,
              _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                  reinterpret_cast<const __m128i*>(ip + off)))),
              _mm256_add_ps(vop[lane], vbio));
          // One prefetch per 64-byte cache line of the lookahead row.
          if (off % 64 == 0) {
            _mm_prefetch(
                reinterpret_cast<const char*>(&ip_next_T0[off]), _MM_HINT_T0);
          }
        }
      }
      if (!normalize_by_lengths || length == 0) {
        for (int lane = 0; lane < num_lanes; ++lane) {
          _mm256_storeu_ps(&op[lane * 8], vop[lane]);
        }
      } else {
        // Mean instead of sum: scale every lane by 1 / length.
        const __m256 vlen_inv = _mm256_set1_ps(1.0f / length);
        for (int lane = 0; lane < num_lanes; ++lane) {
          _mm256_storeu_ps(&op[lane * 8], _mm256_mul_ps(vop[lane], vlen_inv));
        }
      }
    }
  } else {
    // Generic path for arbitrary block sizes: accumulate directly in `out`,
    // vectorizing full 8-float chunks and finishing the tail in scalar code.
    for (int64_t rangeIndex = 0; rangeIndex < output_size; ++rangeIndex) {
      float* op = &out[rangeIndex * block_size];
      int64_t j = 0;
      for (; j + 8 <= block_size; j += 8) {
        _mm256_storeu_ps(op + j, _mm256_setzero_ps());
      }
      for (; j < block_size; j++) {
        op[j] = 0.0f;
      }
      if (dataInd != offsets[rangeIndex] - offsets[0]) {
        return false;
      }
      const int64_t end_offset = offsets[rangeIndex + 1];
      const int64_t length = end_offset - offsets[rangeIndex];
      for (int64_t start = dataInd; dataInd < end_offset - offsets[0];
           ++dataInd) {
        const int64_t idx = indices[dataInd];
        if (idx < 0 || idx >= data_size) {
          return false;
        }
        float wgt = 1.f;
        if (weights) {
          wgt = weights[IS_WEIGHT_POSITIONAL ? (dataInd - start) : dataInd];
        }
        // Fold dequantization scale/bias into the weight (see fast path).
        const float bio = wgt * scale_bias[2 * idx + 1];
        wgt = wgt * scale_bias[2 * idx];
        const __m256 vbio = _mm256_set1_ps(bio);
        const __m256 vwgt = _mm256_set1_ps(wgt);
        const uint8_t* ip = &input[idx * fused_block_size];
        const int64_t next_T0 = (dataInd < index_size - prefdist_T0)
            ? (dataInd + prefdist_T0)
            : dataInd;
        const int64_t idx_pref_T0 = indices[next_T0];
        if (idx_pref_T0 < 0 || idx_pref_T0 >= data_size) {
          return false;
        }
        const uint8_t* ip_next_T0 = &input[idx_pref_T0 * fused_block_size];
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j],
              _mm256_fmadd_ps(
                  vwgt,
                  _mm256_cvtepi32_ps(_mm256_cvtepu8_epi32(_mm_loadl_epi64(
                      reinterpret_cast<const __m128i*>(&ip[j])))),
                  _mm256_add_ps(_mm256_loadu_ps(&op[j]), vbio)));
          _mm_prefetch(
              reinterpret_cast<const char*>(&ip_next_T0[j]), _MM_HINT_T0);
        }
        // Scalar tail mirrors the vector body: op += wgt * v + bio.
        for (; j < block_size; j++) {
          op[j] = std::fma(wgt, static_cast<float>(ip[j]), bio + op[j]);
        }
      }
      if (normalize_by_lengths && length) {
        const float len_inv = 1.0f / length;
        const __m256 vlen_inv = _mm256_set1_ps(len_inv);
        j = 0;
        for (; j + 8 <= block_size; j += 8) {
          _mm256_storeu_ps(
              &op[j], _mm256_mul_ps(_mm256_loadu_ps(&op[j]), vlen_inv));
        }
        for (; j < block_size; j++) {
          op[j] = len_inv * op[j];
        }
      }
    }
  }
  // Success only if every index was consumed exactly once.
  return dataInd == index_size;
}
4436bool EmbeddingLookupIdx_int64_t_uint8_t_float_false__avx2_fma(4437const int64_t block_size,4438const int64_t output_size,4439const int64_t index_size,4440const int64_t data_size,4441const uint8_t* input,4442const int64_t* indices,4443const int64_t* offsets,4444const float* weights,4445const float* scale_bias,4446bool normalize_by_lengths,4447float* out) {4448return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<false>(4449block_size,4450output_size,4451index_size,4452data_size,4453input,4454indices,4455offsets,4456weights,4457scale_bias,4458normalize_by_lengths,4459out);4460}
4461bool EmbeddingLookupIdx_int64_t_uint8_t_float_true__avx2_fma(4462const int64_t block_size,4463const int64_t output_size,4464const int64_t index_size,4465const int64_t data_size,4466const uint8_t* input,4467const int64_t* indices,4468const int64_t* offsets,4469const float* weights,4470const float* scale_bias,4471bool normalize_by_lengths,4472float* out) {4473return EmbeddingLookupIdx_int64_t_uint8_t_float__avx2_fma<true>(4474block_size,4475output_size,4476index_size,4477data_size,4478input,4479indices,4480offsets,4481weights,4482scale_bias,4483normalize_by_lengths,4484out);4485}
4486
4487} // namespace caffe24488