TransformerEngine

Форк
0
123 строки · 4.2 Кб
1
/*************************************************************************
2
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
 *
4
 * See LICENSE for license information.
5
 ************************************************************************/
6

7
#include <cmath>
8
#include <cstring>
9
#include <memory>
10
#include <iomanip>
11
#include <iostream>
12
#include <random>
13
#include <type_traits>
14

15
#include <cuda_bf16.h>
16
#include <cuda_runtime.h>
17
#include <gtest/gtest.h>
18

19
#include <transformer_engine/activation.h>
20
#include "../test_common.h"
21

22
using namespace transformer_engine;
23

24
template <typename IT, typename OT, typename CT>
25
void compute_ref_gelu_cast(const IT *input_h,
26
                           OT *output_h,
27
                           const CT scale,
28
                           CT *amax_h,
29
                           const size_t N,
30
                           const size_t H) {
31
  CT amax  = 0.;
32

33
  for (size_t i = 0; i < N; i++) {
34
    for (size_t j = 0; j < H; j++) {
35
      CT elt = CT(input_h[i * H + j]);
36
      elt = 0.5f * elt * (1.0f + tanhf(0.79788456F * elt *
37
                                       (1.0f + 0.044715f * elt * elt)));
38
      output_h[i * H + j] = OT(scale * elt);
39
      amax = std::abs(elt) > amax ? std::abs(elt) : amax;
40
    }
41
  }
42

43
  *amax_h = amax;
44
}
45

46
template <typename IType, typename OType>
47
void performTestGelu(const size_t N, const size_t H) {
48
  using namespace test;
49

50
  DType itype = TypeInfo<IType>::dtype;
51
  DType otype = TypeInfo<OType>::dtype;
52

53
  Tensor input({ N, H }, itype);
54
  Tensor output({ N, H }, otype);
55

56
  fillUniform(&input);
57
  setRandomScale(&output);
58

59
  std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N*H);
60

61
  nvte_gelu(input.data(), output.data(), 0);
62

63
  float ref_amax;
64
  compute_ref_gelu_cast(input.cpu_dptr<IType>(), ref_output.get(),
65
                        output.scale(), &ref_amax, N, H);
66

67
  cudaDeviceSynchronize();
68
  auto err = cudaGetLastError();
69
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
70

71
  if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
72
    auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
73
    compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
74
  }
75
  auto [atol, rtol] = getTolerances(otype);
76
  compareResults("output_gelu", output, ref_output.get(), atol, rtol);
77
}
78

79
class GELUTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
80
                                                                 transformer_engine::DType,
81
                                                                 std::pair<size_t, size_t>>> {};
82

83
TEST_P(GELUTestSuite, TestGELU) {
84
    using namespace transformer_engine;
85
    using namespace test;
86

87
    const DType input_type = std::get<0>(GetParam());
88
    const DType output_type = std::get<1>(GetParam());
89
    const auto size = std::get<2>(GetParam());
90

91
    TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(input_type, InputType,
92
      TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(output_type, OutputType,
93
        performTestGelu<InputType, OutputType>(size.first, size.second);
94
      );
95
    );
96
}
97

98
namespace {
99

100
std::vector<std::pair<size_t, size_t>> gelu_test_cases = {{2048, 12288},
101
                                                          {768, 1024},
102
                                                          {256, 65536},
103
                                                          {65536, 128},
104
                                                          {256, 256},
105
                                                          {257, 259},
106
                                                          {128, 128+1}};
107

108
}  // namespace
109

110
INSTANTIATE_TEST_SUITE_P(
111
    OperatorTest,
112
    GELUTestSuite,
113
    ::testing::Combine(
114
        ::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
115
        ::testing::ValuesIn(test::all_fp_types),
116
        ::testing::ValuesIn(gelu_test_cases)),
117
    [](const testing::TestParamInfo<GELUTestSuite::ParamType>& info) {
118
      std::string name = test::typeName(std::get<0>(info.param)) + "X" +
119
                         test::typeName(std::get<1>(info.param)) + "X" +
120
                         std::to_string(std::get<2>(info.param).first) + "X" +
121
                         std::to_string(std::get<2>(info.param).second);
122
      return name;
123
    });
124

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.