#include "layernorm_fp16_fake_op.h"

#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>

#include "caffe2/contrib/fakelowp/common.h"
#include "caffe2/contrib/fakelowp/fp16_fma.h"

namespace caffe2 {

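// Normalization step of LayerNorm: Y = (X - mean) / std, followed by the
// optional affine transform Y = gamma * Y + beta. Every intermediate result
// is rounded to fp16 with fbgemm::RoundToFloat16 so this CPU reference
// follows fp16 hardware numerics step by step.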
void LayerNormUtils::calcY(
    const int M,
    const int N,
    const float* X,
    const float* mean,
    const float* std,
    const float* gamma,
    const float* beta,
    float* Y) {
  ConstEigenArrayMap<float> X_arr(X, N, M);
  ConstEigenVectorArrayMap<float> mean_arr(mean, M);
  ConstEigenVectorArrayMap<float> std_arr(std, M);
  EigenArrayMap<float> Y_arr(Y, N, M);

  std::vector<float> normalized(N);
  for (int i = 0; i < M; ++i) {
    float normFactor = float(1.0f / std_arr[i]);
    fbgemm::RoundToFloat16(
        &normFactor, &normFactor, 1, FLAGS_caffe2_fbgemm_fake_fp16_clamp);

    // Subtract the mean, rounding the whole column afterwards.
    for (int j = 0; j < N; ++j) {
      normalized[j] = X_arr.col(i)[j] - mean[i];
    }
    fbgemm::RoundToFloat16(
        normalized.data(),
        normalized.data(),
        N,
        FLAGS_caffe2_fbgemm_fake_fp16_clamp);

    // Multiply by 1/std and round the result directly into Y.
    for (int j = 0; j < N; ++j) {
      normalized[j] *= normFactor;
    }
    fbgemm::RoundToFloat16(
        normalized.data(),
        &Y_arr.col(i)[0],
        N,
        FLAGS_caffe2_fbgemm_fake_fp16_clamp);
  }

  if (gamma != nullptr && beta != nullptr) {
    ConstEigenVectorArrayMap<float> gamma_arr(gamma, N);
    ConstEigenVectorArrayMap<float> beta_arr(beta, N);

    // Y = gamma * Y + beta, computed as one fp16 fma per column with the
    // accumulator seeded to beta.
    std::vector<float> res(N);
    for (int i = 0; i < M; ++i) {
      for (int j = 0; j < N; j++) {
        res[j] = beta[j];
      }
      fake_fp16::fma_fp16(N, &Y_arr.col(i)[0], gamma, res.data());
      for (int j = 0; j < N; j++) {
        Y_arr.col(i)[j] = res[j];
      }
    }
  }
}

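// Sums the 32 partial accumulators with a pairwise (tree) reduction:
// 32 -> 16 -> 8 -> 4 -> 2 -> 1 partial sums, rounding to fp16 after each
// level. This emulates how a 32-lane fp16 accumulator reduces, rather
// than a sequential fp32 sum.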
float LayerNormUtils::ReducedAdd(const std::vector<float>& vec) {
  constexpr int VEC_SIZE = 32;
  std::vector<float> v(vec.begin(), vec.end());

  for (int factor = 2; factor <= VEC_SIZE; factor *= 2) {
    int range = VEC_SIZE / factor; // 16, 8, 4, 2, 1
    for (int i = 0; i < range; ++i) {
      v[i] = v[2 * i] + v[2 * i + 1];
    }
    fbgemm::RoundToFloat16(
        v.data(), v.data(), range, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
  }
  return v[0];
}

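// Per-row mean and standard deviation in emulated fp16 arithmetic:
//   mean   = sum(x) / N      (accumulated as fma of x with 1/N)
//   E[X^2] = sum(x * x/N)    (accumulated as fma of x with x/N)
//   var    = E[X^2] - mean^2
//   std    = sqrt(var + eps)
// Accumulation runs 32 lanes at a time through fake_fp16::fma_fp16.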
void LayerNormUtils::calcMeanStd(
    const int M,
    const int N,
    const float eps,
    const float* X,
    float* mean,
    float* std) {
  ConstEigenArrayMap<float> X_arr(X, N, M);

  std::vector<float> sqr(M, 0.0f);
  std::vector<float> var(M, 0.0f);
  float inv_N_val = 1.0f / N;
  fbgemm::RoundToFloat16(
      &inv_N_val, &inv_N_val, 1, FLAGS_caffe2_fbgemm_fake_fp16_clamp);

  constexpr int VEC_SIZE = 32;
  std::vector<float> inv_N_vec(VEC_SIZE, inv_N_val);
  std::vector<float> inv_N_prod_vec(VEC_SIZE, 0.0f);
  std::vector<float> avgVec(VEC_SIZE, 0.0f);
  std::vector<float> sqrVec(VEC_SIZE, 0.0f);
  std::vector<float> negMeanVec(M, 0.0f);
  int numVecs = N / VEC_SIZE;
  int tailSize = N - (numVecs * VEC_SIZE);

  // Work on an fp16-rounded copy of the input.
  std::vector<float> X_fp16(M * N);
  fbgemm::RoundToFloat16(
      X, X_fp16.data(), M * N, FLAGS_caffe2_fbgemm_fake_fp16_clamp);

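  // For each row, accumulate sum(x)/N into avgVec and sum(x * x/N) into
  // sqrVec, 32 lanes at a time; the tail block handles the remainder when
  // N is not a multiple of VEC_SIZE.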
  for (int i = 0; i < M; ++i) {
    std::fill(avgVec.begin(), avgVec.end(), 0.0f);
    std::fill(sqrVec.begin(), sqrVec.end(), 0.0f);
    for (int j = 0; j < numVecs; ++j) {
      // avgVec += x * (1/N)
      fake_fp16::fma_fp16(
          VEC_SIZE,
          &X_fp16[i * N + VEC_SIZE * j],
          inv_N_vec.data(),
          avgVec.data());
      // sqrVec += x * (x * (1/N))
      for (int k = 0; k < VEC_SIZE; k++) {
        inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * j + k] * inv_N_val;
      }
      fbgemm::RoundToFloat16(
          inv_N_prod_vec.data(),
          inv_N_prod_vec.data(),
          VEC_SIZE,
          FLAGS_caffe2_fbgemm_fake_fp16_clamp);
      fake_fp16::fma_fp16(
          VEC_SIZE,
          &X_fp16[i * N + VEC_SIZE * j],
          inv_N_prod_vec.data(),
          sqrVec.data());
    }

    if (tailSize > 0) {
      fake_fp16::fma_fp16(
          tailSize,
          &X_fp16[i * N + VEC_SIZE * numVecs],
          inv_N_vec.data(),
          avgVec.data());
      for (int k = 0; k < tailSize; k++) {
        inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * numVecs + k] * inv_N_val;
      }
      fbgemm::RoundToFloat16(
          inv_N_prod_vec.data(),
          inv_N_prod_vec.data(),
          tailSize,
          FLAGS_caffe2_fbgemm_fake_fp16_clamp);
      fake_fp16::fma_fp16(
          tailSize,
          &X_fp16[i * N + VEC_SIZE * numVecs],
          inv_N_prod_vec.data(),
          sqrVec.data());
    }

    mean[i] = ReducedAdd(avgVec);
    sqr[i] = ReducedAdd(sqrVec);
  }

  // Compute variance: var = E[X^2] - mean^2, done as sqr += mean * (-mean).
  std::copy(mean, mean + M, negMeanVec.begin());
  std::transform(
      negMeanVec.cbegin(),
      negMeanVec.cend(),
      negMeanVec.begin(),
      std::negate<float>());
  fake_fp16::fma_fp16(M, mean, negMeanVec.data(), sqr.data());
  std::copy(sqr.cbegin(), sqr.cend(), var.begin());

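  // Cancellation in E[X^2] - mean^2 can leave a slightly negative variance
  // in fp16; clamp to zero (with a rate-limited warning) before adding eps
  // and taking the square root.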
  float teps = eps;
  std::vector<float> tmpVec(M, 0.0f);
  fbgemm::RoundToFloat16(&teps, &teps, 1, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
  for (int i = 0; i < M; ++i) {
    if (var[i] < 0.0f) {
      LOG_EVERY_N(WARNING, 1000) << "Variance " << var[i]
                                 << " negative, resetting to 0.";
      var[i] = 0.0f;
    }
    tmpVec[i] = var[i] + teps;
  }
  fbgemm::RoundToFloat16(
      tmpVec.data(),
      tmpVec.data(),
      M,
      FLAGS_caffe2_fbgemm_fake_fp16_clamp);

  // Clamp any remaining negative values, then take the per-row square root.
  int i = 0;
  for (auto& v : tmpVec) {
    if (v < 0.0f) {
      LOG_EVERY_N(WARNING, 1000) << "Variance " << v
                                 << " negative, resetting to 0.";
      v = 0.0f;
    }
    std[i] = std::sqrt(v);
    ++i;
  }

  fbgemm::RoundToFloat16(
      std,
      std,
      M,
      FLAGS_caffe2_fbgemm_fake_fp16_clamp);
}

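// Operator registrations: the plain fp16-emulation LayerNorm and a variant
// whose template parameter selects an int8-quantized output path.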
REGISTER_CPU_OPERATOR(LayerNormFakeFP16NNPI, LayerNormFakeFp16Op<false>);
OPERATOR_SCHEMA(LayerNormFakeFP16NNPI).NumInputs({1, 3}).NumOutputs(3);

REGISTER_CPU_OPERATOR(
    LayerNormInt8QuantizeFakeNNPI,
    LayerNormFakeFp16Op<true>);
OPERATOR_SCHEMA(LayerNormInt8QuantizeFakeNNPI)
    .IdenticalTypeAndShape()
    .NumInputs({1, 3})
    .NumOutputs(3);

} // namespace caffe2
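
// Illustrative usage (not part of this file): the registered ops are invoked
// through an ordinary Caffe2 OperatorDef. The blob names below are an
// assumption; the schema only fixes the arity (1 or 3 inputs, 3 outputs).
//
//   caffe2::OperatorDef def;
//   def.set_type("LayerNormFakeFP16NNPI");
//   def.add_input("X");      // optionally also "gamma" and "beta"
//   def.add_output("Y");
//   def.add_output("mean");
//   def.add_output("std");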