TransformerEngine
115 строк · 3.7 Кб
/*************************************************************************
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
 *
 * See LICENSE for license information.
 ************************************************************************/

7#include <cmath>
8#include <cstring>
9#include <iomanip>
10#include <iostream>
11#include <memory>
12#include <random>
13#include <type_traits>
14
15#include <cuda_bf16.h>
16#include <cuda_runtime.h>
17#include <gtest/gtest.h>
18
19#include <transformer_engine/activation.h>
20#include "../test_common.h"
21
22using namespace transformer_engine;
23
// Host reference for GeGLU followed by a scaled cast.
//
// input_h is read as N rows of 2*H elements. For each row, columns [0, H) go
// through the tanh approximation of GELU and are multiplied element-wise by the
// gate taken from columns [H, 2H). Each product is multiplied by `scale`,
// converted to OT and stored in output_h (N rows of H elements).
// *amax_h receives the largest absolute value of the products before scaling.
// All arithmetic is performed in the compute type CT.
template <typename IT, typename OT, typename CT>
void compute_ref_geglu_cast(const IT *input_h, OT *output_h, const CT scale, CT *amax_h,
                            const size_t N, const size_t H) {
  // Row stride of the input (GELU half followed by gate half). Kept as size_t so
  // H * 2 is never narrowed into an int.
  const size_t col = H * 2;
  CT amax = CT(0);

  for (size_t i = 0; i < N; i++) {
    const IT *in_row = input_h + i * col;
    OT *out_row = output_h + i * H;
    for (size_t j = 0; j < H; j++) {
      CT gelu_elt = CT(in_row[j]);
      // tanh approximation of GELU; 0.79788456 ~= sqrt(2 / pi). std::tanh selects the
      // overload for CT (the float overload when CT is float) instead of always
      // truncating the argument to float as tanhf would.
      gelu_elt = 0.5f * gelu_elt *
                 (1.0f + std::tanh(0.79788456F * gelu_elt *
                                   (1.0f + 0.044715f * gelu_elt * gelu_elt)));
      const CT gate_elt = CT(in_row[H + j]);
      const CT elt = gelu_elt * gate_elt;
      out_row[j] = OT(scale * elt);
      // amax tracks the unscaled product; a NaN never compares greater, so it
      // leaves amax unchanged.
      const CT abs_elt = std::abs(elt);
      amax = abs_elt > amax ? abs_elt : amax;
    }
  }

  *amax_h = amax;
}

46template <typename IType, typename OType>
47void performTestGEGLU(const size_t N, const size_t H) {
48using namespace test;
49
50DType itype = TypeInfo<IType>::dtype;
51DType otype = TypeInfo<OType>::dtype;
52
53Tensor input({N, H * 2}, itype);
54Tensor output({N, H}, otype);
55
56fillUniform(&input);
57setRandomScale(&output);
58
59std::unique_ptr<OType[]> ref_output = std::make_unique<OType[]>(N * H);
60
61nvte_geglu(input.data(), output.data(), 0);
62
63float ref_amax;
64compute_ref_geglu_cast(input.cpu_dptr<IType>(), ref_output.get(), output.scale(), &ref_amax, N,
65H);
66
67cudaDeviceSynchronize();
68auto err = cudaGetLastError();
69ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
70
71if (otype == DType::kFloat8E4M3 || otype == DType::kFloat8E5M2) {
72auto [atol_amax, rtol_amax] = getTolerances(DType::kFloat32);
73compareResults("amax", output.amax(), ref_amax, atol_amax, rtol_amax);
74}
75auto [atol, rtol] = getTolerances(otype);
76compareResults("output_gelu", output, ref_output.get(), atol, rtol);
77}
78
79class GeGLUTestSuite
80: public ::testing::TestWithParam<std::tuple<
81transformer_engine::DType, transformer_engine::DType, std::pair<size_t, size_t>>> {};
82
83TEST_P(GeGLUTestSuite, TestGeGLU) {
84using namespace transformer_engine;
85using namespace test;
86
87const DType input_type = std::get<0>(GetParam());
88const DType output_type = std::get<1>(GetParam());
89const auto size = std::get<2>(GetParam());
90
91TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
92input_type, InputType,
93TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(
94output_type, OutputType,
95performTestGEGLU<InputType, OutputType>(size.first, size.second);););
96}
97
namespace {

// (N, H) shapes handed to performTestGEGLU: the input is N x 2H and the output
// is N x H. The last two entries use odd, non-power-of-two dimensions.
std::vector<std::pair<size_t, size_t>> test_cases = {
    {4096, 2048},
    {768, 2816},
    {256, 5120},
    {128, 10240},
    {256, 256},
    {257, 259},
    {128, 129},
};

}  // namespace

105INSTANTIATE_TEST_SUITE_P(
106OperatorTest, GeGLUTestSuite,
107::testing::Combine(::testing::Values(DType::kFloat32, DType::kBFloat16, DType::kFloat16),
108::testing::ValuesIn(test::all_fp_types), ::testing::ValuesIn(test_cases)),
109[](const testing::TestParamInfo<GeGLUTestSuite::ParamType> &info) {
110std::string name = test::typeName(std::get<0>(info.param)) + "X" +
111test::typeName(std::get<1>(info.param)) + "X" +
112std::to_string(std::get<2>(info.param).first) + "X" +
113std::to_string(std::get<2>(info.param).second);
114return name;
115});
116