TransformerEngine

Форк
0
/
test_common.cu 
263 строки · 8.3 Кб
1
/*************************************************************************
2
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
 *
4
 * See LICENSE for license information.
5
 ************************************************************************/
6

7

8
#include "test_common.h"
9

10
#include <algorithm>
11
#include <memory>
12
#include <random>
13

14
#include <gtest/gtest.h>
15

16
#include <transformer_engine/transformer_engine.h>
17
#include "util/logging.h"
18

19
namespace test {
20

21
// Floating-point dtypes exercised by the parameterized tests:
// FP32, FP16, BF16 and both FP8 variants (E5M2, E4M3).
std::vector<DType> all_fp_types = {DType::kFloat32,
                                   DType::kFloat16,
                                   DType::kBFloat16,
                                   DType::kFloat8E5M2,
                                   DType::kFloat8E4M3};
26

27
bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
28
  if (s1.ndim != s2.ndim) return false;
29

30
  for (size_t i = 0; i < s1.ndim; ++i) {
31
    if (s1.data[i] != s2.data[i]) return false;
32
  }
33

34
  return true;
35
}
36

37
// Size in bytes of one element of the given DType.
// The TYPE_SWITCH_ALL macro expands to a switch covering every supported
// dtype, and each arm returns, so control never falls off the end.
size_t typeToSize(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
      return TypeInfo<T>::size;
  });
}
43

44
// Human-readable name of a DType, e.g. for test-case labels.
// Throws std::out_of_range (via unordered_map::at) for a dtype that has no
// registered name.
const std::string &typeName(DType type) {
  static const std::unordered_map<DType, std::string> names = {
      {DType::kByte, "byte"},
      {DType::kInt32, "int32"},
      {DType::kInt64, "int64"},
      {DType::kFloat32, "float32"},
      {DType::kFloat16, "float16"},
      {DType::kBFloat16, "bfloat16"},
      {DType::kFloat8E4M3, "float8e4m3"},
      {DType::kFloat8E5M2, "float8e5m2"},
  };
  return names.at(type);
}
56

57
size_t product(const NVTEShape &shape) {
58
    size_t ret = 1;
59
    for (size_t i = 0; i < shape.ndim; ++i) {
60
      ret *= shape.data[i];
61
    }
62
    return ret;
63
}
64

65
// Allocates a zero-filled device buffer of product(shape) elements of `type`
// plus a matching zero-filled host mirror. For FP8 dtypes, additionally
// allocates device scalars for amax/scale/scale_inv and zero-initialized
// host-side copies of each.
// NOTE(review): cudaMalloc/cudaMemset return codes are not checked here;
// an allocation failure would only surface later as an invalid pointer.
Tensor::Tensor(const NVTEShape &shape, const DType type) {
    size_t s = typeToSize(type);            // bytes per element
    size_t total_size = product(shape) * s; // full buffer size in bytes
    void *dptr = nullptr;
    cpu_data_ = nullptr;
    amax_cpu_data_ = nullptr;
    scale_cpu_data_ = nullptr;
    scale_inv_cpu_data_ = nullptr;
    float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr;
    if (total_size != 0) {
        cudaMalloc((void**)&dptr, total_size);  // NOLINT(*)
        cudaMemset(dptr, 0, total_size);
        cpu_data_ = std::make_unique<unsigned char[]>(total_size);
        // make_unique<T[]> already value-initializes; the explicit loop keeps
        // the zeroed state of the host mirror obvious.
        for (size_t i = 0; i < total_size; ++i) {
          cpu_data_[i] = 0;
        }
    }
    if (isFp8Type(type)) {
      // FP8 tensors carry scaling metadata: one float each on the device...
      cudaMalloc((void**)&amax, sizeof(float));  // NOLINT(*)
      cudaMemset(amax, 0, sizeof(float));
      cudaMalloc((void**)&scale, sizeof(float));  // NOLINT(*)
      cudaMemset(scale, 0, sizeof(float));
      cudaMalloc((void**)&scale_inv, sizeof(float));  // NOLINT(*)
      cudaMemset(scale_inv, 0, sizeof(float));
      // ...and zero-initialized host-side mirrors of the same scalars.
      amax_cpu_data_ = std::make_shared<float>();
      *amax_cpu_data_ = 0;
      scale_cpu_data_ = std::make_shared<float>();
      *scale_cpu_data_ = 0;
      scale_inv_cpu_data_ = std::make_shared<float>();
      *scale_inv_cpu_data_ = 0;
    }
    tensor_ = TensorWrapper(dptr, shape, type, amax, scale, scale_inv);
}
98

99
void Tensor::to_cpu() const {
100
  const NVTEShape s = tensor_.shape();
101
  const size_t size = product(s) * typeToSize(tensor_.dtype());
102
  cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost);
103
  if (isFp8Type(dtype())) {
104
  cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float),
105
             cudaMemcpyDeviceToHost);
106
  cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float),
107
             cudaMemcpyDeviceToHost);
108
  cudaMemcpy(scale_inv_cpu_data_.get(), tensor_.scale_inv(), sizeof(float),
109
             cudaMemcpyDeviceToHost);
110
  }
111
}
112

113
void Tensor::from_cpu() const {
114
  const NVTEShape s = tensor_.shape();
115
  const size_t size = product(s) * typeToSize(tensor_.dtype());
116
  cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice);
117
  if (isFp8Type(dtype())) {
118
  cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
119
             cudaMemcpyHostToDevice);
120
  cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
121
             cudaMemcpyHostToDevice);
122
  cudaMemcpy(tensor_.scale_inv(), scale_inv_cpu_data_.get(), sizeof(float),
123
             cudaMemcpyHostToDevice);
124
  }
125
}
126

127
void Tensor::set_scale(float scale) {
128
  if (isFp8Type(dtype())) {
129
    NVTE_CHECK(scale_cpu_data_);
130
    *scale_cpu_data_ = scale;
131
    from_cpu();
132
  }
133
}
134

135
void Tensor::set_scale_inv(float scale_inv) {
136
  if (isFp8Type(dtype())) {
137
    NVTE_CHECK(scale_inv_cpu_data_);
138
    *scale_inv_cpu_data_ = scale_inv;
139
    from_cpu();
140
  }
141
}
142

143
void Tensor::shareFP8Meta(const Tensor &other) {
144
  if(isFp8Type(dtype()) && isFp8Type(other.dtype())) {
145
    tensor_ = TensorWrapper(dptr(), shape(), dtype(),
146
                            other.tensor_.amax(),
147
                            other.tensor_.scale(),
148
                            other.tensor_.scale_inv());
149
    to_cpu();
150
  }
151
}
152

153
using std::to_string;

// Formats a vector as "[a, b, c]" using to_string on each element.
// Bug fix: the original appended ", " after every element and then called
// pop_back() twice unconditionally — for an empty vector that is pop_back()
// on an empty std::string, which is undefined behavior. Separators are now
// inserted between elements instead, so the empty vector yields "[]".
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  bool first = true;
  for (const auto &x : v) {
    if (!first) s += ", ";
    s += to_string(x);
    first = false;
  }
  return s + "]";
}
165

166
std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
167
  std::vector<size_t> ret;
168
  size_t current_i = i;
169
  for (size_t current = shape.ndim - 1;
170
       current > 0;
171
       --current) {
172
    ret.push_back(current_i % shape.data[current]);
173
    current_i /= shape.data[current];
174
  }
175
  ret.push_back(current_i);
176
  std::reverse(ret.begin(), ret.end());
177
  return ret;
178
}
179

180
// Element-wise comparison of a device tensor against a host reference buffer.
// A value pair fails only if it is outside both tolerances AND the difference
// cannot be explained by round-to-nearest landing on opposite sides of the
// midpoint — except for FP32, where the atol/rtol check alone is decisive.
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    double atol, double rtol) {
  test.to_cpu();
  const size_t N = product(test.shape());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = test.cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
      // Outside absolute tolerance AND outside relative tolerance (the
      // relative check is skipped when the reference value is zero).
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        // Perturb the midpoint slightly in both directions, round each through
        // the tensor's dtype, and accept the mismatch as rounding noise iff
        // the two casts land exactly on the smaller and larger of t and r.
        const double mean = (t + r) / 2;
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      ASSERT_FALSE(assertion) << "Error in tensor " << name << std::endl
                              << "Mismatch at place " << to_string(unravel(i, test.shape()))
                              << " (" << std::to_string(i) << "): " << t << " vs " << r;

    }
  );
}
210

211
void compareResults(const std::string &name, const float test, const float ref,
212
                    double atol, double rtol) {
213
  double t = static_cast<double>(test);
214
  double r = static_cast<double>(ref);
215
  bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
216
  ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
217
                         << "Mismatch: " << t << " vs " << r;
218

219
}
220

221
std::pair<double, double> getTolerances(const DType type) {
222
  switch(type) {
223
    case DType::kFloat32:
224
      return {1e-6, 5e-6};
225
    case DType::kFloat16:
226
      return {1e-5, 1e-3};
227
    case DType::kBFloat16:
228
      return {1e-5, 1e-2};
229
    case DType::kFloat8E4M3:
230
    case DType::kFloat8E5M2:
231
      return {1e-2, 1e-2};
232
    default:
233
      NVTE_CHECK("Invalid type!");
234
  }
235
  return {0, 0};
236
}
237

238
// Fills the host mirror of `t` with values drawn from U(-2, 1), assigns a
// random scale_inv (which takes effect only for FP8 tensors), and uploads
// everything to the device. The generator is static with a fixed seed, so
// results are reproducible across a run but advance with every call.
void fillUniform(Tensor *t) {
  const size_t size = product(t->shape());
  static std::mt19937 gen(12345);
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, {
      T *data = t->cpu_dptr<T>();
      for (size_t i = 0; i < size; ++i) {
          // Each draw is rounded through the tensor's element type.
          data[i] = T(dis(gen));
      }
  });
  // Consumes one extra draw even for non-FP8 dtypes (set_scale_inv is a
  // no-op there).
  t->set_scale_inv(dis(gen));
  t->from_cpu();
}
251

252
// Assigns a random FP8 scale drawn from U(-2, 1). Has no effect on non-FP8
// tensors (Tensor::set_scale is a no-op for them). The generator is static
// with a fixed seed for reproducibility.
void setRandomScale(Tensor *t) {
  static std::mt19937 rng(12345);
  std::uniform_real_distribution<> dist(-2.0, 1.0);
  t->set_scale(dist(rng));
}
258

259
// True for either of the two FP8 formats (E4M3 or E5M2), false otherwise.
bool isFp8Type(DType type) {
    switch (type) {
      case DType::kFloat8E4M3:
      case DType::kFloat8E5M2:
        return true;
      default:
        return false;
    }
}
262

263
}  // namespace test
264

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.