TransformerEngine
98 строк · 3.1 Кб
1/*************************************************************************
2* Copyright (c) 2022-2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3*
4* See LICENSE for license information.
5************************************************************************/
6
7#include <cstring>
8#include <iomanip>
9#include <iostream>
10#include <memory>
11#include <random>
12
13#include <cuda_bf16.h>
14#include <cuda_runtime.h>
15#include <gtest/gtest.h>
16
17#include <transformer_engine/transpose.h>
18#include "../test_common.h"
19
20using namespace transformer_engine;
21
22namespace {
23
// Host-side reference transpose.
//
// Treats `data` as a row-major N x H matrix and writes its transpose into
// `output`, laid out as a row-major H x N matrix. Both buffers must hold at
// least N * H elements; nothing is written when either dimension is zero.
template <typename Type>
void compute_ref(const Type *data, Type *output,
                 const size_t N, const size_t H) {
  for (size_t row = 0; row < N; ++row) {
    // Start of the current source row.
    const Type *src_row = data + row * H;
    for (size_t col = 0; col < H; ++col) {
      // Element (row, col) of the input lands at (col, row) of the output.
      output[col * N + row] = src_row[col];
    }
  }
}
33
34template <typename Type>
35void performTest(const size_t N, const size_t H) {
36using namespace test;
37
38DType dtype = TypeInfo<Type>::dtype;
39
40Tensor input({ N, H }, dtype);
41Tensor output({ H, N }, dtype);
42
43std::unique_ptr<Type[]> ref_output = std::make_unique<Type[]>(N * H);
44
45fillUniform(&input);
46
47nvte_transpose(input.data(), output.data(), 0);
48
49compute_ref<Type>(input.cpu_dptr<Type>(), ref_output.get(), N, H);
50
51cudaDeviceSynchronize();
52auto err = cudaGetLastError();
53ASSERT_EQ(err, cudaSuccess) << cudaGetErrorString(err);
54auto [atol, rtol] = getTolerances(dtype);
55compareResults("output", output, ref_output.get(), atol, rtol);
56}
57
// Matrix shapes exercised by the transpose test. Each pair is forwarded to
// performTest as (N, H).
std::vector<std::pair<size_t, size_t>> test_cases = {
    // Power-of-two and other evenly divisible shapes.
    {2048, 12288},
    {768, 1024},
    {256, 65536},
    {65536, 128},
    {256, 256},
    {120, 2080},
    {8, 8},
    // Prime-sized shapes.
    {1223, 1583},  // 200th and 250th primes
    {1, 541},      // single row; 541 is the 100th prime
    {1987, 1},     // single column; 1987 is the 300th prime
};
68} // namespace
69
70class TTestSuite : public ::testing::TestWithParam<std::tuple<transformer_engine::DType,
71std::pair<size_t, size_t>>> {};
72
73TEST_P(TTestSuite, TestTranspose) {
74using namespace transformer_engine;
75using namespace test;
76
77const DType type = std::get<0>(GetParam());
78const auto size = std::get<1>(GetParam());
79
80TRANSFORMER_ENGINE_TYPE_SWITCH_ALL(type, T,
81performTest<T>(size.first, size.second);
82);
83}
84
85
86
87INSTANTIATE_TEST_SUITE_P(
88OperatorTest,
89TTestSuite,
90::testing::Combine(
91::testing::ValuesIn(test::all_fp_types),
92::testing::ValuesIn(test_cases)),
93[](const testing::TestParamInfo<TTestSuite::ParamType>& info) {
94std::string name = test::typeName(std::get<0>(info.param)) + "X" +
95std::to_string(std::get<1>(info.param).first) + "X" +
96std::to_string(std::get<1>(info.param).second);
97return name;
98});
99