/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include "test_common.h"

#include <algorithm>
#include <memory>
#include <random>

#include <gtest/gtest.h>

#include <transformer_engine/transformer_engine.h>
#include "util/logging.h"

namespace test {
21std::vector<DType> all_fp_types = {DType::kFloat32,
22DType::kFloat16,
23DType::kBFloat16,
24DType::kFloat8E5M2,
25DType::kFloat8E4M3};
26
27bool areShapesEqual(const NVTEShape &s1, const NVTEShape &s2) {
28if (s1.ndim != s2.ndim) return false;
29
30for (size_t i = 0; i < s1.ndim; ++i) {
31if (s1.data[i] != s2.data[i]) return false;
32}
33
34return true;
35}
36
// Size in bytes of one element of the given DType.
// TRANSFORMER_ENGINE_TYPE_SWITCH_ALL expands to a switch over every supported
// DType, and each branch returns, so control never falls off the end.
size_t typeToSize(DType type) {
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
  {
    return TypeInfo<T>::size;
  });
}

44const std::string &typeName(DType type) {
45static const std::unordered_map<DType, std::string> name_map = {
46{DType::kByte, "byte"},
47{DType::kInt32, "int32"},
48{DType::kInt64, "int64"},
49{DType::kFloat32, "float32"},
50{DType::kFloat16, "float16"},
51{DType::kBFloat16, "bfloat16"},
52{DType::kFloat8E4M3, "float8e4m3"},
53{DType::kFloat8E5M2, "float8e5m2"}};
54return name_map.at(type);
55}
56
57size_t product(const NVTEShape &shape) {
58size_t ret = 1;
59for (size_t i = 0; i < shape.ndim; ++i) {
60ret *= shape.data[i];
61}
62return ret;
63}
64
65Tensor::Tensor(const NVTEShape &shape, const DType type) {
66size_t s = typeToSize(type);
67size_t total_size = product(shape) * s;
68void *dptr = nullptr;
69cpu_data_ = nullptr;
70amax_cpu_data_ = nullptr;
71scale_cpu_data_ = nullptr;
72scale_inv_cpu_data_ = nullptr;
73float *amax = nullptr, *scale = nullptr, *scale_inv = nullptr;
74if (total_size != 0) {
75cudaMalloc((void**)&dptr, total_size); // NOLINT(*)
76cudaMemset(dptr, 0, total_size);
77cpu_data_ = std::make_unique<unsigned char[]>(total_size);
78for (size_t i = 0; i < total_size; ++i) {
79cpu_data_[i] = 0;
80}
81}
82if (isFp8Type(type)) {
83cudaMalloc((void**)&amax, sizeof(float)); // NOLINT(*)
84cudaMemset(amax, 0, sizeof(float));
85cudaMalloc((void**)&scale, sizeof(float)); // NOLINT(*)
86cudaMemset(scale, 0, sizeof(float));
87cudaMalloc((void**)&scale_inv, sizeof(float)); // NOLINT(*)
88cudaMemset(scale_inv, 0, sizeof(float));
89amax_cpu_data_ = std::make_shared<float>();
90*amax_cpu_data_ = 0;
91scale_cpu_data_ = std::make_shared<float>();
92*scale_cpu_data_ = 0;
93scale_inv_cpu_data_ = std::make_shared<float>();
94*scale_inv_cpu_data_ = 0;
95}
96tensor_ = TensorWrapper(dptr, shape, type, amax, scale, scale_inv);
97}
98
99void Tensor::to_cpu() const {
100const NVTEShape s = tensor_.shape();
101const size_t size = product(s) * typeToSize(tensor_.dtype());
102cudaMemcpy(cpu_data_.get(), tensor_.dptr(), size, cudaMemcpyDeviceToHost);
103if (isFp8Type(dtype())) {
104cudaMemcpy(amax_cpu_data_.get(), tensor_.amax(), sizeof(float),
105cudaMemcpyDeviceToHost);
106cudaMemcpy(scale_cpu_data_.get(), tensor_.scale(), sizeof(float),
107cudaMemcpyDeviceToHost);
108cudaMemcpy(scale_inv_cpu_data_.get(), tensor_.scale_inv(), sizeof(float),
109cudaMemcpyDeviceToHost);
110}
111}
112
113void Tensor::from_cpu() const {
114const NVTEShape s = tensor_.shape();
115const size_t size = product(s) * typeToSize(tensor_.dtype());
116cudaMemcpy(tensor_.dptr(), cpu_data_.get(), size, cudaMemcpyHostToDevice);
117if (isFp8Type(dtype())) {
118cudaMemcpy(tensor_.amax(), amax_cpu_data_.get(), sizeof(float),
119cudaMemcpyHostToDevice);
120cudaMemcpy(tensor_.scale(), scale_cpu_data_.get(), sizeof(float),
121cudaMemcpyHostToDevice);
122cudaMemcpy(tensor_.scale_inv(), scale_inv_cpu_data_.get(), sizeof(float),
123cudaMemcpyHostToDevice);
124}
125}
126
127void Tensor::set_scale(float scale) {
128if (isFp8Type(dtype())) {
129NVTE_CHECK(scale_cpu_data_);
130*scale_cpu_data_ = scale;
131from_cpu();
132}
133}
134
135void Tensor::set_scale_inv(float scale_inv) {
136if (isFp8Type(dtype())) {
137NVTE_CHECK(scale_inv_cpu_data_);
138*scale_inv_cpu_data_ = scale_inv;
139from_cpu();
140}
141}
142
143void Tensor::shareFP8Meta(const Tensor &other) {
144if(isFp8Type(dtype()) && isFp8Type(other.dtype())) {
145tensor_ = TensorWrapper(dptr(), shape(), dtype(),
146other.tensor_.amax(),
147other.tensor_.scale(),
148other.tensor_.scale_inv());
149to_cpu();
150}
151}
152
using std::to_string;

// Formats a vector as "[a, b, c]"; an empty vector yields "[]".
// Bug fix: the previous version unconditionally called pop_back() twice to
// strip the trailing ", ", which for an empty vector popped from the
// single-character string "[" — undefined behavior on the second pop.
template <typename T>
std::string to_string(const std::vector<T> &v) {
  std::string s = "[";
  for (size_t i = 0; i < v.size(); ++i) {
    if (i != 0) s += ", ";
    s += to_string(v[i]);
  }
  return s + "]";
}

166std::vector<size_t> unravel(const size_t i, const NVTEShape &shape) {
167std::vector<size_t> ret;
168size_t current_i = i;
169for (size_t current = shape.ndim - 1;
170current > 0;
171--current) {
172ret.push_back(current_i % shape.data[current]);
173current_i /= shape.data[current];
174}
175ret.push_back(current_i);
176std::reverse(ret.begin(), ret.end());
177return ret;
178}
179
// Element-wise comparison of tensor `test` (copied to host first) against raw
// host buffer `ref`, interpreted as the tensor's dtype. A pair of values
// mismatches only when it fails both the absolute (atol) and relative (rtol)
// tolerance. For reduced-precision types, a mismatch is still forgiven when
// it is explained by round-to-nearest landing on either side of the true
// value: the midpoint of (t, r), nudged by 1e-6 in both directions and cast
// through T, must reproduce exactly min(t, r) and max(t, r).
void compareResults(const std::string &name, const Tensor &test, const void *ref,
                    double atol, double rtol) {
  test.to_cpu();
  const size_t N = product(test.shape());
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(test.dtype(), T,
    const T *test_data = test.cpu_dptr<T>();
    const T *ref_data = reinterpret_cast<const T*>(ref);
    for (size_t i = 0; i < N; ++i) {
      double t = static_cast<double>(test_data[i]);
      double r = static_cast<double>(ref_data[i]);
      // Mismatch requires exceeding atol AND (when ref != 0) exceeding rtol.
      bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
      /* For Float32 the floating point comparison is enough to error out */
      bool assertion = mismatch && test.dtype() == DType::kFloat32;
      if (mismatch && !assertion) {
        /* Check if it is just a failure of round to nearest choosing different
           side of the real value */
        const double mean = (t + r) / 2;
        // Nudge the midpoint slightly toward/away from zero on each side.
        const double mean_p = mean >= 0 ? mean * (1 + 1e-6) : mean * (1 - 1e-6);
        const double mean_m = mean >= 0 ? mean * (1 - 1e-6) : mean * (1 + 1e-6);
        const double cast_mean_p = static_cast<double>(static_cast<T>(mean_p));
        const double cast_mean_m = static_cast<double>(static_cast<T>(mean_m));
        // Forgive the mismatch iff rounding the nudged midpoints brackets
        // {t, r} exactly; otherwise it is a genuine error.
        assertion = !(cast_mean_m == std::min(t,r) && cast_mean_p == std::max(t,r));
      }
      ASSERT_FALSE(assertion) << "Error in tensor " << name << std::endl
                              << "Mismatch at place " << to_string(unravel(i, test.shape()))
                              << " (" << std::to_string(i) << "): " << t << " vs " << r;
    }
  );
}

211void compareResults(const std::string &name, const float test, const float ref,
212double atol, double rtol) {
213double t = static_cast<double>(test);
214double r = static_cast<double>(ref);
215bool mismatch = fabs(t - r) > atol && (r == 0 || fabs((t - r) / r) > rtol);
216ASSERT_FALSE(mismatch) << "Error in " << name << std::endl
217<< "Mismatch: " << t << " vs " << r;
218
219}
220
221std::pair<double, double> getTolerances(const DType type) {
222switch(type) {
223case DType::kFloat32:
224return {1e-6, 5e-6};
225case DType::kFloat16:
226return {1e-5, 1e-3};
227case DType::kBFloat16:
228return {1e-5, 1e-2};
229case DType::kFloat8E4M3:
230case DType::kFloat8E5M2:
231return {1e-2, 1e-2};
232default:
233NVTE_CHECK("Invalid type!");
234}
235return {0, 0};
236}
237
// Fills t's host buffer with values drawn uniformly from [-2, 1) using a
// fixed seed (12345) so runs are reproducible, draws a random scale_inv
// (no-op for non-FP8 dtypes), and uploads everything to the device.
void fillUniform(Tensor *t) {
  const size_t size = product(t->shape());
  // Shared across calls: successive tensors get different (but deterministic)
  // values.
  static std::mt19937 gen(12345);
  std::uniform_real_distribution<> dis(-2.0, 1.0);
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(t->dtype(), T, {
    T *data = t->cpu_dptr<T>();
    for (size_t i = 0; i < size; ++i) {
      data[i] = T(dis(gen));
    }
  });
  t->set_scale_inv(dis(gen));
  // set_scale_inv only syncs FP8 metadata; this push uploads the freshly
  // generated payload as well.
  t->from_cpu();
}

252void setRandomScale(Tensor *t) {
253static std::mt19937 gen(12345);
254std::uniform_real_distribution<> dis(-2.0, 1.0);
255const float scale = dis(gen);
256t->set_scale(scale);
257}
258
259bool isFp8Type(DType type) {
260return type == DType::kFloat8E4M3 || type == DType::kFloat8E5M2;
261}
262
}  // namespace test