TransformerEngine

Форк
0
/
test_transpose.cu 
98 строк · 3.1 Кб
1
/*************************************************************************
2
 * Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3
 *
4
 * See LICENSE for license information.
5
 ************************************************************************/
6

7
#include <cstring>
8
#include <iomanip>
9
#include <iostream>
10
#include <memory>
11
#include <random>
12

13
#include <cuda_bf16.h>
14
#include <cuda_runtime.h>
15
#include <gtest/gtest.h>
16

17
#include <transformer_engine/transpose.h>
18
#include "../test_common.h"
19

20
using namespace transformer_engine;
21

22
namespace {
23

24
template <typename Type>
25
void compute_ref(const Type *data,  Type *output,
26
                 const size_t N, const size_t H) {
27
  for (size_t i = 0; i < N; ++i) {
28
    for (size_t j = 0; j < H; ++j) {
29
      output[j * N + i] = data[i * H + j];
30
    }
31
  }
32
}
33

34
template <typename Type>
35
void performTest(const size_t N, const size_t H) {
36
  using namespace test;
37

38
  DType dtype = TypeInfo<Type>::dtype;
39

40
  Tensor input({ N, H }, dtype);
41
  Tensor output({ H, N }, dtype);
42

43
  std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
44

45
  fillUniform(&input);
46

47
  nvte_transpose(input.data(), output.data(), 0);
48

49
  compute_ref<Type>(input.cpu_dptr<Type>(), ref_output.get(), N, H);
50

51
  cudaDeviceSynchronize();
52
  auto err = cudaGetLastError();
53
  ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
54
  auto [atol, rtol] = getTolerances(dtype);
55
  compareResults("output", output, ref_output.get(), atol, rtol);
56
}
57

58
std::vector<std::pair<size_t, size_t>> test_cases = {{2048, 12288},
59
                                                     {768, 1024},
60
                                                     {256, 65536},
61
                                                     {65536, 128},
62
                                                     {256, 256},
63
                                                     {120, 2080},
64
                                                     {8, 8},
65
                                                     {1223, 1583}, // Primes 200, 250
66
                                                     {1, 541},     // Prime 100
67
                                                     {1987, 1}};   // Prime 300
68
}  // namespace
69

70
class TTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
71
                                                              std::pair<size_t, size_t>>> {};
72

73
TEST_P(TTestSuite, TestTranspose) {
74
  using namespace transformer_engine;
75
  using namespace test;
76

77
  const DType type = std::get<0>(GetParam());
78
  const auto size = std::get<1>(GetParam());
79

80
  TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
81
    performTest<T>(size.first, size.second);
82
  );
83
}
84

85

86

87
INSTANTIATE_TEST_SUITE_P(
88
  OperatorTest,
89
  TTestSuite,
90
  ::testing::Combine(
91
      ::testing::ValuesIn(test::all_fp_types),
92
      ::testing::ValuesIn(test_cases)),
93
  [](const testing::TestParamInfo<TTestSuite::ParamType>& info) {
94
    std::string name = test::typeName(std::get<0>(info.param)) + "X" +
95
                       std::to_string(std::get<1>(info.param).first) + "X" +
96
                       std::to_string(std::get<1>(info.param).second);
97
    return name;
98
  });
99

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.