TransformerEngine

test_cast_transpose_dgeglu.cu 
147 lines · 5.3 KB
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

#include <cmath>
#include <cstring>
#include <iomanip>
#include <iostream>
#include <memory>
#include <random>

#include <cuda_bf16.h>
#include <cuda_runtime.h>
#include <gtest/gtest.h>

#include <transformer_engine/transpose.h>
#include "../test_common.h"

using namespace transformer_engine;

namespace {

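// CPU reference implementations of the tanh-approximation GELU and its
// derivative, used to check the GPU kernel below.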
template <typename CType, typename IType>
inline CType gelu(const IType val) {
  CType cval = val;
  return cval * (0.5f + 0.5f * tanhf(cval * (0.79788456f + 0.03567741f * cval * cval)));
}

template <typename CType, typename IType>
inline CType dgelu(const IType val) {
  CType cval = val;
  const CType tanh_out = tanhf(0.79788456f * cval * (1.f + 0.044715f * cval * cval));
  return 0.5f * cval * ((1.f - tanh_out * tanh_out) * (0.79788456f + 0.1070322243f * cval * cval)) +
         0.5f * (1.f + tanh_out);
}

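// CPU reference for the dGeGLU + cast + transpose operation. Each input row
// holds 2*H values: the GELU-branch activations in the first H columns and the
// gate values in the second H columns. For every element, the gradients with
// respect to both halves are computed, scaled, cast to the output type, written
// to the row-major output and its transpose, and the running amax is updated.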
template <typename IT, typename OT, typename CT>
void compute_ref_cast_transpose_dgated_gelu(const IT *grad_h, const IT *input_h, const CT scale,
                                            OT *output_c_h, OT *output_t_h, CT *amax_h,
                                            const size_t N, const size_t H) {
  CT amax = 0.;

  const size_t col = H * 2;
  for (size_t i = 0; i < N; i++) {
    for (size_t j = 0; j < H; j++) {
      CT grad_elt = CT(grad_h[i * H + j]);
      CT gelu_elt = CT(input_h[i * col + j]);
      CT gate_elt = CT(input_h[i * col + H + j]);

      CT after_dgelu = dgelu<CT, CT>(gelu_elt) * grad_elt * gate_elt;
      CT after_dgate = grad_elt * gelu<CT, CT>(gelu_elt);

      amax = std::abs(after_dgelu) > amax ? std::abs(after_dgelu) : amax;
      amax = std::abs(after_dgate) > amax ? std::abs(after_dgate) : amax;

      output_c_h[i * col + j] = static_cast<OT>(scale * after_dgelu);
      output_c_h[i * col + H + j] = static_cast<OT>(scale * after_dgate);

      output_t_h[j * N + i] = static_cast<OT>(scale * after_dgelu);
      output_t_h[(j + H) * N + i] = static_cast<OT>(scale * after_dgate);
    }
  }

  *amax_h = amax;
}

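// Runs the fused GPU kernel and compares its outputs (and FP8 metadata, where
// applicable) against the CPU reference for one (N, H) size and one
// input/output type combination.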
template <typename IType, typename OType>
void performTest(const size_t N, const size_t H) {
  using namespace test;
  using CType = fp32;

  DType itype = TypeInfo<IType>::dtype;
  DType otype = TypeInfo<OType>::dtype;

  Tensor grad({N, H}, itype);
  Tensor input({N, H * 2}, itype);
  Tensor output_c({N, H * 2}, otype);
  Tensor output_t({H * 2, N}, otype);

  fillUniform(&grad);
  fillUniform(&input);
  setRandomScale(&output_c);
  output_t.shareFP8Meta(output_c);

  std::unique_ptr<OType[]> ref_output_c = std::make_unique<OType[]>(N * H * 2);
  std::unique_ptr<OType[]> ref_output_t = std::make_unique<OType[]>(N * H * 2);

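  // Launch the dGeGLU + cast + transpose kernel on the default CUDA stream.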
  nvte_dgeglu_cast_transpose(grad.data(), input.data(), output_c.data(), output_t.data(), 0);

  CType ref_amax;
  compute_ref_cast_transpose_dgated_gelu(grad.cpu_dptr<IType>(), input.cpu_dptr<IType>(),
                                         output_c.scale(), ref_output_c.get(), ref_output_t.get(),
                                         &ref_amax, N, H);

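  // Wait for the kernel to finish and make sure no CUDA error was reported.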
  cudaDeviceSynchronize();
  auto err = cudaGetLastError();
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);

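  // For FP8 outputs, also check the quantization metadata: the computed amax
  // and the inverse of the scale used for casting.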
  if (isFp8Type(otype)) {
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
    compareResults("amax", output_c.amax(), ref_amax, atol_amax, rtol_amax);
    float ref_scale_inv = 1.f / output_c.scale();
    compareResults("scale_inv", output_c.scale_inv(), ref_scale_inv, atol_amax, rtol_amax);
  }

  auto [atol, rtol] = getTolerances(otype);
  compareResults("output_c", output_c, ref_output_c.get(), atol, rtol);
  compareResults("output_t", output_t, ref_output_t.get(), atol, rtol);
}

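// Test sizes as (rows N, hidden dimension H); the gated input is N x 2H.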
std::vector<std::pair<size_t, size_t>> test_cases = {{64, 400},   {4096, 2048}, {768, 2816},
                                                     {256, 5120}, {128, 10240}, {256, 256}};

}  // namespace

class DGeGLUCTTestSuite
    : public ::testing::TestWithParam<std::tuple<
          transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};

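// Dispatches the runtime DType parameters to a concrete performTest
// instantiation via nested compile-time type switches.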
TEST_P(DGeGLUCTTestSuite, TestDGeGLUCT) {
  using namespace transformer_engine;
  using namespace test;

  const DType input_type = std::get<0>(GetParam());
  const DType output_type = std::get<1>(GetParam());
  const auto size = std::get<2>(GetParam());

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      input_type, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          output_type, OutputType, performTest<InputType, OutputType>(size.first, size.second);););
}

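// Instantiate over every combination of input type, FP8 output type, and
// problem size; the lambda builds a readable test name by joining the two
// type names and the dimensions with "X".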
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, DGeGLUCTTestSuite,
    ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::Values(DType::kFloat8E5M2, DType::kFloat8E4M3),
                       ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<DGeGLUCTTestSuite::ParamType> &info) {
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
                         test::typeName(std::get<1>(info.param)) + "X" +
                         std::to_string(std::get<2>(info.param).first) + "X" +
                         std::to_string(std::get<2>(info.param).second);
      return name;
    });