TransformerEngine
147 строк · 5.3 Кб
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/
6
7#include <cmath>
8#include <cstring>
9#include <iomanip>
10#include <iostream>
11#include <memory>
12#include <random>
13
14#include <cuda_bf16.h>
15#include <cuda_runtime.h>
16#include <gtest/gtest.h>
17
18#include <transformer_engine/transpose.h>
19#include "../test_common.h"
20
21using namespace transformer_engine;
22
23namespace {
24
// Host reference for the tanh-approximated GELU activation.
//
// Evaluates x * (0.5 + 0.5 * tanh(x * (0.79788456 + 0.03567741 * x^2))).
//
// CType is the type the value is converted to before any arithmetic and the
// type of the result; IType is the type of the incoming value.
template <typename CType, typename IType>
inline CType gelu(const IType val) {
  const CType x = val;
  // Argument of the tanh: x * (a + b * x^2).
  const auto tanh_arg = x * (0.79788456f + 0.03567741f * x * x);
  // tanhf evaluates in single precision regardless of CType.
  const auto tanh_val = tanhf(tanh_arg);
  return x * (0.5f + 0.5f * tanh_val);
}
30
// Host reference for the derivative of the tanh-approximated GELU.
//
// With t = tanh(0.79788456 * x * (1 + 0.044715 * x^2)) the result is
//   0.5 * x * (1 - t^2) * (0.79788456 + 0.1070322243 * x^2) + 0.5 * (1 + t).
//
// CType is the type the value is converted to before any arithmetic and the
// type of the result; IType is the type of the incoming value.
template <typename CType, typename IType>
inline CType dgelu(const IType val) {
  const CType x = val;
  const CType t = tanhf(0.79788456f * x * (1.f + 0.044715f * x * x));

  // (1 - t^2) is the derivative of tanh with respect to its argument.
  const auto one_minus_t_sq = 1.f - t * t;
  // Derivative of the tanh argument with respect to x.
  const auto arg_slope = 0.79788456f + 0.1070322243f * x * x;

  const auto through_tanh = 0.5f * x * (one_minus_t_sq * arg_slope);
  const auto through_linear = 0.5f * (1.f + t);
  return through_tanh + through_linear;
}
38
39template <typename IT, typename OT, typename CT>
40void compute_ref_cast_transpose_dgated_gelu(const IT *grad_h, const IT *input_h, const CT scale,
41OT *output_c_h, OT *output_t_h, CT *amax_h,
42const size_t N, const size_t H) {
43CT amax = 0.;
44
45const size_t col = H * 2;
46for (size_t i = 0; i < N; i++) {
47for (size_t j = 0; j < H; j++) {
48CT grad_elt = CT(grad_h[i * H + j]);
49CT gelu_elt = CT(input_h[i * col + j]);
50CT gate_elt = CT(input_h[i * col + H + j]);
51
52CT after_dgelu = dgelu<CT, CT>(gelu_elt) * grad_elt * gate_elt;
53CT after_dgate = grad_elt * gelu<CT, CT>(gelu_elt);
54
55amax = std::abs(after_dgelu) > amax ? std::abs(after_dgelu) : amax;
56amax = std::abs(after_dgate) > amax ? std::abs(after_dgate) : amax;
57
58output_c_h[i * col + j] = static_cast<OT>(scale * after_dgelu);
59output_c_h[i * col + H + j] = static_cast<OT>(scale * after_dgate);
60
61output_t_h[j * N + i] = static_cast<OT>(scale * after_dgelu);
62output_t_h[(j + H) * N + i] = static_cast<OT>(scale * after_dgate);
63}
64}
65
66*amax_h = amax;
67}
68
69template <typename IType, typename OType>
70void performTest(const size_t N, const size_t H) {
71using namespace test;
72using CType = fp32;
73
74DType itype = TypeInfo<IType>::dtype;
75DType otype = TypeInfo<OType>::dtype;
76
77Tensor grad({N, H}, itype);
78Tensor input({N, H * 2}, itype);
79Tensor output_c({N, H * 2}, otype);
80Tensor output_t({H * 2, N}, otype);
81
82fillUniform(&grad);
83fillUniform(&input);
84setRandomScale(&output_c);
85output_t.shareFP8Meta(output_c);
86
87std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N * H * 2);
88std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N * H * 2);
89
90nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0);
91
92CType ref_amax;
93compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
94output_c.scale(), ref_output_c.get(), ref_output_t.get(),
95&ref_amax, N, H);
96
97cudaDeviceSynchronize();
98auto err = cudaGetLastError();
99ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
100
101if (isFp8Type(otype)) {
102auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
103compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
104float ref_scale_inv = 1.f / output_c.scale();
105compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
106}
107
108auto [atol, rtol] = getTolerances(otype);
109compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
110compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
111}
112
// Shapes exercised by the test suite; each pair is passed to performTest as (N, H).
std::vector<std::pair<size_t, size_t>> test_cases = {
    {64, 400},
    {4096, 2048},
    {768, 2816},
    {256, 5120},
    {128, 10240},
    {256, 256},
};
115
116} // namespace
117
118class DGeGLUCTTestSuite
119: public ::testing::TestWithParam<std::tuple<
120transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
121
122TEST_P(DGeGLUCTTestSuite, TestDGeGLUCT) {
123using namespace transformer_engine;
124using namespace test;
125
126const DType input_type = std::get<0>(GetParam());
127const DType output_type = std::get<1>(GetParam());
128const auto size = std::get<2>(GetParam());
129
130TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
131input_type, InputType,
132TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
133output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
134}
135
136INSTANTIATE_TEST_SUITE_P(
137OperatorTest, DGeGLUCTTestSuite,
138::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
139::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
140::testing::ValuesIn(test_cases)),
141[](const testing::TestParamInfo<DGeGLUCTTestSuite::ParamType> &info) {
142std::string name = test::typeName(std::get<0>(info.param)) + "X" +
143test::typeName(std::get<1>(info.param)) + "X" +
144std::to_string(std::get<2>(info.param).first) + "X" +
145std::to_string(std::get<2>(info.param).second);
146return name;
147});
148