TransformerEngine

Форк
0
115 строк · 3.7 Кб
1
/*************************************************************************
2
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
 *
4
 * See LICENSE for license information.
5
 ************************************************************************/
6

7
#include <cmath>
8
#include <cstring>
9
#include <iomanip>
10
#include <iostream>
11
#include <memory>
12
#include <random>
13
#include <type_traits>
14

15
#include <cuda_bf16.h>
16
#include <cuda_runtime.h>
17
#include <gtest/gtest.h>
18

19
#include <transformer_engine/activation.h>
20
#include "../test_common.h"
21

22
using namespace transformer_engine;
23

24
// CPU reference for GeGLU with output scaling.
//
//   input_h  : row-major buffer of N rows with 2*H elements each. In every row the first H
//              elements go through the tanh-approximated GELU and the last H elements are the
//              gate values they are multiplied with.
//   output_h : row-major buffer of N rows with H elements each; receives
//              OT(scale * gelu(x) * gate).
//   scale    : factor applied to each product before it is converted to OT.
//   amax_h   : receives the largest |gelu(x) * gate| encountered, measured before scaling
//              (stays 0 when N or H is 0).
template <typename IT, typename OT, typename CT>
void compute_ref_geglu_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h,
                            const size_t N, const size_t H) {
  CT amax = 0.;

  // Input row stride. Kept as size_t: storing it in an int would narrow the value and
  // corrupt the index arithmetic below once 2 * H no longer fits in an int.
  const size_t col = H * 2;

  for (size_t i = 0; i < N; i++) {
    const IT *in_row = input_h + i * col;
    OT *out_row = output_h + i * H;

    for (size_t j = 0; j < H; j++) {
      // GELU half of the row, tanh approximation
      const CT x = CT(in_row[j]);
      const CT gelu = 0.5f * x * (1.0f + tanhf(0.79788456F * x * (1.0f + 0.044715f * x * x)));

      // Gate half of the row
      const CT gate = CT(in_row[H + j]);
      const CT elt = gelu * gate;

      out_row[j] = OT(scale * elt);

      // amax tracks the unscaled magnitude
      const CT magnitude = std::abs(elt);
      if (magnitude > amax) {
        amax = magnitude;
      }
    }
  }

  *amax_h = amax;
}
45

46
template <typename IType, typename OType>
47
void performTestGEGLU(const size_t N, const size_t H) {
48
  using namespace test;
49

50
  DType itype = TypeInfo<IType>::dtype;
51
  DType otype = TypeInfo<OType>::dtype;
52

53
  Tensor input({N, H * 2}, itype);
54
  Tensor output({N, H}, otype);
55

56
  fillUniform(&input);
57
  setRandomScale(&output);
58

59
  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
60

61
  nvte_geglu(input.data(), output.data(), 0);
62

63
  float ref_amax;
64
  compute_ref_geglu_cast(input.cpu_dptr<IType>(), ref_output.get(), output.scale(), &ref_amax, N,
65
                         H);
66

67
  cudaDeviceSynchronize();
68
  auto err = cudaGetLastError();
69
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
70

71
  if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
72
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
73
    compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
74
  }
75
  auto [atol, rtol] = getTolerances(otype);
76
  compareResults("output_gelu", output, ref_output.get(), atol, rtol);
77
}
78

79
class GeGLUTestSuite
80
    : public ::testing::TestWithParam<std::tuple<
81
          transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
82

83
// Maps the runtime (input dtype, output dtype, size pair) parameter onto the matching
// performTestGEGLU<InputType, OutputType> instantiation and runs it.
TEST_P(GeGLUTestSuite, TestGeGLU) {
  using namespace transformer_engine;
  using namespace test;

  const auto &param = GetParam();
  const DType in_dtype = std::get<0>(param);
  const DType out_dtype = std::get<1>(param);
  const auto shape = std::get<2>(param);  // forwarded as (N, H)

  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
      in_dtype, InputType,
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
          out_dtype, OutputType,
          performTestGEGLU<InputType, OutputType>(shape.first, shape.second);););
}
97

98
namespace {

// Size pairs fed to the GeGLU suite. TestGeGLU forwards each pair to performTestGEGLU as
// (N, H), where the input tensor is {N, 2 * H} and the output tensor is {N, H}.
std::vector<std::pair<size_t, size_t>> test_cases = {
    {4096, 2048},
    {768, 2816},
    {256, 5120},
    {128, 10240},
    {256, 256},
    {257, 259},
    {128, 128 + 1},
};

}  // namespace
104

105
// Instantiates the suite over {FP32, BF16, FP16} input dtypes x test::all_fp_types output
// dtypes x test_cases. Generated test names have the form
// <input type>X<output type>X<first size>X<second size>.
INSTANTIATE_TEST_SUITE_P(
    OperatorTest, GeGLUTestSuite,
    ::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
                       ::testing::ValuesIn(test::all_fp_types), ::testing::ValuesIn(test_cases)),
    [](const testing::TestParamInfo<GeGLUTestSuite::ParamType> &info) {
      const auto &[in_dtype, out_dtype, shape] = info.param;

      std::string name = test::typeName(in_dtype);
      name += "X";
      name += test::typeName(out_dtype);
      name += "X";
      name += std::to_string(shape.first);
      name += "X";
      name += std::to_string(shape.second);
      return name;
    });
116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.