layernorm_fp16_fake_op.cc 
#include <algorithm>
#include <cmath>
#include <functional>
#include <vector>

#include "layernorm_fp16_fake_op.h"
#include "caffe2/contrib/fakelowp/common.h"
#include "caffe2/contrib/fakelowp/fp16_fma.h"

namespace caffe2 {

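// Normalizes each of the M rows: Y = (X - mean) / std, then optionally
// Y = Y * gamma + beta. Every intermediate result is rounded to fp16
// (fbgemm::RoundToFloat16 / fake_fp16::fma_fp16) to emulate the precision
// of the target low-precision hardware.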
void LayerNormUtils::calcY(
    const int M,
    const int N,
    const float* X,
    const float* mean,
    const float* std,
    const float* gamma,
    const float* beta,
    float* Y) {
  ConstEigenArrayMap<float> X_arr(X, N, M);
  ConstEigenVectorArrayMap<float> mean_arr(mean, M);
  ConstEigenVectorArrayMap<float> std_arr(std, M);
  EigenArrayMap<float> Y_arr(Y, N, M);

  std::vector<float> normalized(N);
  for (int i = 0; i < M; ++i) {
    float normFactor = 1.0f / std_arr[i];
    fbgemm::RoundToFloat16(
        &normFactor, &normFactor, 1, FLAGS_caffe2_fbgemm_fake_fp16_clamp);

    for (int j = 0; j < N; ++j) {
      normalized[j] = X_arr.col(i)[j] - mean[i];
    }
    fbgemm::RoundToFloat16(
        normalized.data(),
        normalized.data(),
        N,
        FLAGS_caffe2_fbgemm_fake_fp16_clamp);
    for (int j = 0; j < N; ++j) {
      normalized[j] *= normFactor;
    }
    fbgemm::RoundToFloat16(
        normalized.data(),
        &Y_arr.col(i)[0],
        N,
        FLAGS_caffe2_fbgemm_fake_fp16_clamp);
  }

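  // Optional affine step: Y = Y * gamma + beta. res starts out as beta and
  // the fake-fp16 fused multiply-add accumulates Y * gamma into it.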
  if (gamma != nullptr && beta != nullptr) {
    ConstEigenVectorArrayMap<float> gamma_arr(gamma, N);
    ConstEigenVectorArrayMap<float> beta_arr(beta, N);

    for (int i = 0; i < M; ++i) {
      std::vector<float> res(N);
      for (int j = 0; j < N; j++) {
        res[j] = beta[j];
      }
      fake_fp16::fma_fp16(N, &Y_arr.col(i)[0], gamma, res.data());
      for (int j = 0; j < N; j++) {
        Y_arr.col(i)[j] = res[j];
      }
    }
  }
}

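// Sums a 32-element accumulator with a pairwise (tree) reduction, rounding
// the partial sums to fp16 after each level so the result matches a
// reduction performed entirely at half precision.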
float LayerNormUtils::ReducedAdd(const std::vector<float>& vec) {
  constexpr int VEC_SIZE = 32;
  std::vector<float> v(vec.begin(), vec.end());

  for (int factor = 2; factor <= VEC_SIZE; factor *= 2) {
    int range = VEC_SIZE / factor;

    for (int i = 0; i < range; ++i) {
      v[i] = v[2 * i] + v[2 * i + 1];
    }
    fbgemm::RoundToFloat16(
        v.data(), v.data(), range, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
  }

  return v[0];
}

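// Computes the per-row mean and standard deviation of the M x N input.
// Sums are accumulated with 32-wide fake-fp16 fused multiply-adds, the
// variance is formed as E[x^2] - mean^2, clamped at zero, offset by eps,
// and square-rooted; fp16 rounding is applied after each stage.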
void LayerNormUtils::calcMeanStd(
    const int M,
    const int N,
    const float eps,
    const float* X,
    float* mean,
    float* std) {
  ConstEigenArrayMap<float> X_arr(X, N, M);

  std::vector<float> sqr(M, 0.0f);
  std::vector<float> var(M, 0.0f);
  float inv_N_val = 1.0f / N;
  fbgemm::RoundToFloat16(
      &inv_N_val, &inv_N_val, 1, FLAGS_caffe2_fbgemm_fake_fp16_clamp);

  constexpr int VEC_SIZE = 32;
  std::vector<float> inv_N_vec(VEC_SIZE, inv_N_val);
  std::vector<float> inv_N_prod_vec(VEC_SIZE, 0.0f);
  std::vector<float> avgVec(VEC_SIZE, 0.0f);
  std::vector<float> sqrVec(VEC_SIZE, 0.0f);
  std::vector<float> negMeanVec(M, 0.0f);
  int numVecs = N / VEC_SIZE;
  int tailSize = N - (numVecs * VEC_SIZE);

  std::vector<float> X_fp16(M * N);
  fbgemm::RoundToFloat16(
      X, X_fp16.data(), M * N, FLAGS_caffe2_fbgemm_fake_fp16_clamp);

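  // Accumulate sum(x) / N into avgVec and sum(x * x) / N into sqrVec for
  // each row, processing the row in chunks of VEC_SIZE elements.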
  for (int i = 0; i < M; ++i) {
    std::fill(avgVec.begin(), avgVec.end(), 0.0f);
    std::fill(sqrVec.begin(), sqrVec.end(), 0.0f);
    for (int j = 0; j < numVecs; ++j) {
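      // avgVec += x * (1/N); inv_N_prod_vec holds x * (1/N) rounded to
      // fp16, so the second fma accumulates x * (x/N) into sqrVec.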
      fake_fp16::fma_fp16(
          VEC_SIZE,
          &X_fp16[i * N + VEC_SIZE * j],
          inv_N_vec.data(),
          avgVec.data());
      for (int k = 0; k < VEC_SIZE; k++) {
        inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * j + k] * inv_N_val;
      }
      fbgemm::RoundToFloat16(
          inv_N_prod_vec.data(),
          inv_N_prod_vec.data(),
          VEC_SIZE,
          FLAGS_caffe2_fbgemm_fake_fp16_clamp);

      fake_fp16::fma_fp16(
          VEC_SIZE,
          &X_fp16[i * N + VEC_SIZE * j],
          inv_N_prod_vec.data(),
          sqrVec.data());
    }

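    // Handle the tail of N % VEC_SIZE elements the same way.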
    if (tailSize > 0) {
      fake_fp16::fma_fp16(
          tailSize,
          &X_fp16[i * N + VEC_SIZE * numVecs],
          inv_N_vec.data(),
          avgVec.data());
      for (int k = 0; k < tailSize; k++) {
        inv_N_prod_vec[k] = X_fp16[i * N + VEC_SIZE * numVecs + k] * inv_N_val;
      }
      fbgemm::RoundToFloat16(
          inv_N_prod_vec.data(),
          inv_N_prod_vec.data(),
          tailSize,
          FLAGS_caffe2_fbgemm_fake_fp16_clamp);

      fake_fp16::fma_fp16(
          tailSize,
          &X_fp16[i * N + VEC_SIZE * numVecs],
          inv_N_prod_vec.data(),
          sqrVec.data());
    }
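    // Collapse the 32-wide accumulators into the scalar mean and mean of
    // squares for this row.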
    mean[i] = ReducedAdd(avgVec);
    sqr[i] = ReducedAdd(sqrVec);
  }

  // Compute variance and standard deviation.
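  // sqr currently holds E[x^2]; the fused multiply-add below adds
  // mean * (-mean), leaving the variance E[x^2] - mean^2 in sqr.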
  std::copy(mean, mean + M, negMeanVec.begin());
  std::transform(
      negMeanVec.cbegin(),
      negMeanVec.cend(),
      negMeanVec.begin(),
      std::negate<float>());
  fake_fp16::fma_fp16(M, mean, negMeanVec.data(), sqr.data());
  std::copy(sqr.cbegin(), sqr.cend(), var.begin());

  float teps = eps;
  std::vector<float> tmpVec(M, 0.0f);
  fbgemm::RoundToFloat16(&teps, &teps, 1, FLAGS_caffe2_fbgemm_fake_fp16_clamp);
  int i = 0;
  for (auto& v : var) {
    if (v < 0.0f) {
      LOG_EVERY_N(WARNING, 1000) << "Variance " << v
          << " negative, resetting to 0.";
      v = 0.0f;
    }
    tmpVec[i] = var[i] + teps;
    ++i;
  }
  fbgemm::RoundToFloat16(
      tmpVec.data(),
      tmpVec.data(),
      M,
      FLAGS_caffe2_fbgemm_fake_fp16_clamp);
  i = 0;
  for (auto& v : tmpVec) {
    if (v < 0.0f) {
      LOG_EVERY_N(WARNING, 1000) << "Variance " << v
          << " negative, resetting to 0.";
      v = 0.0f;
    }
    std[i] = std::sqrt(v);
    ++i;
  }
  fbgemm::RoundToFloat16(
      std,
      std,
      M,
      FLAGS_caffe2_fbgemm_fake_fp16_clamp);
}

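// Fake-fp16 LayerNorm operator for NNPI emulation, plus a variant that also
// quantizes its output to int8 (selected by the boolean template parameter).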
REGISTER_CPU_OPERATOR(LayerNormFakeFP16NNPI, LayerNormFakeFp16Op<false>);
OPERATOR_SCHEMA(LayerNormFakeFP16NNPI).NumInputs({1, 3}).NumOutputs(3);

REGISTER_CPU_OPERATOR(
    LayerNormInt8QuantizeFakeNNPI,
    LayerNormFakeFp16Op<true>);
OPERATOR_SCHEMA(LayerNormInt8QuantizeFakeNNPI)
    .IdenticalTypeAndShape()
    .NumInputs({1, 3})
    .NumOutputs(3);

} // namespace caffe2
