// pytorch — TensorExpr conv2d tests (original file listing: 234 lines, 6.7 KB).
#include <gtest/gtest.h>

#include <vector>

#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
#include <torch/csrc/jit/tensorexpr/loopnest.h>
#include <torch/csrc/jit/tensorexpr/operators/conv2d.h>
#include <torch/csrc/jit/tensorexpr/tensor.h>
#include <torch/torch.h>
namespace torch {
namespace jit {

// Short aliases: `te` = tensor-expression IR, `F` = torch::nn functional ops.
namespace te = torch::jit::tensorexpr;
namespace F = torch::nn::functional;

// The LLVM-codegen tests below only build when LLVM support is enabled.
#ifdef TORCH_ENABLE_LLVM
17// Generate test data with few bits of precision, to minimize error
18// accumulation from floating-point reordering.
19static at::Tensor genTestData(c10::IntArrayRef args) {20return at::trunc(at::randn(args) * 256.0f) / 256.0f;21}
22
23TEST(Conv, DepthwiseConv2D) {24constexpr int N = 1, C = 72, H = 56, W = 56;25constexpr int K = 72, R = 3, S = 3;26constexpr int kPad = 1, kStride = 2, kGroups = C;27constexpr int CperG = C / kGroups;28
29te::BufHandle input("input", {N, C, H, W}, te::kFloat);30te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);31te::BufHandle bias("bias", {K}, te::kFloat);32te::Tensor output =33te::conv2d_depthwise(input, weight, bias, kStride, kPad, kGroups);34
35te::LoopNest loop({output});36loop.simplify();37loop.prepareForCodegen();38te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, bias, output});39
40auto it = genTestData({N, C, H, W});41auto wt = genTestData({K, CperG, R, S});42auto bt = genTestData({K});43auto ref = at::conv2d(it, wt, bt, kStride, kPad, /*dilation=*/1, kGroups);44auto ot = at::zeros_like(ref);45cg.call(46{it.data_ptr<float>(),47wt.data_ptr<float>(),48bt.data_ptr<float>(),49ot.data_ptr<float>()});50
51ASSERT_TRUE(at::allclose(ref, ot));52}
53
54TEST(Conv, DepthwiseConv2DNoBias) {55constexpr int N = 1, C = 72, H = 56, W = 56;56constexpr int K = 72, R = 3, S = 3;57constexpr int kPad = 1, kStride = 2, kGroups = C;58constexpr int CperG = C / kGroups;59
60te::BufHandle input("input", {N, C, H, W}, te::kFloat);61te::BufHandle weight("weight", {K, CperG, R, S}, te::kFloat);62te::Tensor output =63te::conv2d_depthwise(input, weight, kStride, kPad, kGroups);64
65te::LoopNest loop({output});66loop.simplify();67loop.prepareForCodegen();68te::LLVMCodeGen cg(loop.root_stmt(), {input, weight, output});69
70auto it = genTestData({N, C, H, W});71auto wt = genTestData({K, CperG, R, S});72auto ref =73at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups);74auto ot = at::zeros_like(ref);75cg.call({it.data_ptr<float>(), wt.data_ptr<float>(), ot.data_ptr<float>()});76
77ASSERT_TRUE(at::allclose(ref, ot));78}
79
80TEST(Conv, DepthwiseConv2DDynamicShapes) {81te::VarHandle N_var("N", te::kInt);82te::VarHandle C_var("C", te::kInt);83te::VarHandle H_var("H", te::kInt);84te::VarHandle W_var("W", te::kInt);85te::VarHandle K_var("K", te::kInt);86te::VarHandle CperG_var("CperG", te::kInt);87te::VarHandle R_var("R", te::kInt);88te::VarHandle S_var("S", te::kInt);89te::VarHandle kPad_var("kPad", te::kInt);90te::VarHandle kStride_var("kStride", te::kInt);91te::VarHandle kGroups_var("kGroups", te::kInt);92
93te::BufHandle input("input", {N_var, C_var, H_var, W_var}, te::kFloat);94te::BufHandle weight("weight", {K_var, CperG_var, R_var, S_var}, te::kFloat);95te::Tensor output = te::conv2d_depthwise(96input,97weight,98N_var,99C_var,100H_var,101W_var,102K_var,103CperG_var,104R_var,105S_var,106kStride_var,107kPad_var,108kGroups_var);109
110te::LoopNest loop({output});111loop.simplify();112loop.prepareForCodegen();113std::vector<te::CodeGen::BufferArg> buffer_args = {114input,115weight,116N_var,117C_var,118H_var,119W_var,120K_var,121CperG_var,122R_var,123S_var,124kPad_var,125kStride_var,126kGroups_var,127output};128te::LLVMCodeGen cg(loop.root_stmt(), buffer_args);129
130constexpr int N = 1, C = 72, H = 56, W = 56;131constexpr int K = 72, R = 3, S = 3;132constexpr int kPad = 1, kStride = 2, kGroups = C;133constexpr int CperG = C / kGroups;134
135auto it = genTestData({N, C, H, W});136auto wt = genTestData({K, CperG, R, S});137auto ref =138at::conv2d(it, wt, at::Tensor(), kStride, kPad, /*dilation=*/1, kGroups);139auto ot = at::zeros_like(ref);140std::vector<te::CodeGen::CallArg> call_args = {141it.data_ptr<float>(),142wt.data_ptr<float>(),143N,144C,145H,146W,147K,148CperG,149R,150S,151kPad,152kStride,153kGroups,154ot.data_ptr<float>()};155cg.call(call_args);156
157ASSERT_TRUE(at::allclose(ref, ot));158}
159
160#endif161
162TEST(Conv, Conv2D) {163// Input dimensions.164constexpr int N = 1;165constexpr int C = 3;166constexpr int H = 11;167constexpr int W = 11;168
169// Filter dimensions.170constexpr int K = 8;171constexpr int R = 3;172constexpr int S = 3;173
174// Output dims.175constexpr int OH = H - R + 1;176constexpr int OW = W - S + 1;177
178// Compute reference result.179at::Tensor input = torch::randn({N, C, H, W});180at::Tensor filter = torch::randn({K, C, R, S});181at::Tensor ref = F::conv2d(input, filter);182
183// Double check the output size is as expected.184ASSERT_EQ(ref.size(0), N);185ASSERT_EQ(ref.size(1), K);186ASSERT_EQ(ref.size(2), OH);187ASSERT_EQ(ref.size(3), OW);188
189te::BufHandle inputB("input", {N, C, H, W}, te::kFloat);190te::BufHandle filterB("filter", {K, C, R, S}, te::kFloat);191
192te::Tensor conv = te::Reduce(193"conv",194{N, K, OH, OW},195te::Sum(),196// FIXME: We have to use a `std::vector` parameter here and then unpack197// it, because we don't have an overload allowing for an arbitrary number198// of ExprHandle/VarHandle parameters.199[&](const std::vector<te::VarHandle>& v) {200auto const& n = v[0];201auto const& k = v[1];202auto const& oh = v[2];203auto const& ow = v[3];204auto const& c = v[4];205auto const& r = v[5];206auto const& s = v[6];207// FIXME: We have to use `call` and construct a `std::vector` here208// because the `operator()` overload is only specialized for a small209// number of arguments.210return inputB.load(n, c, oh + r, ow + s) * filterB.load(k, c, r, s);211},212// FIXME: If you forget one of the reduction dims, you get a segfault.213// Could that be caught by a verifier?214{C, R, S});215
216// FIXME: It'd be nice to have a single header that pulls in things like217// LoopNest, IRSimplifier, etc.218te::LoopNest loop({conv});219loop.prepareForCodegen();220te::StmtPtr s = loop.root_stmt();221s = te::IRSimplifier::simplify(s);222
223at::Tensor result = at::empty_like(ref);224te::SimpleIREvaluator cg(s, {inputB, filterB, conv});225cg.call(226{input.data_ptr<float>(),227filter.data_ptr<float>(),228result.data_ptr<float>()});229
230ASSERT_TRUE(at::allclose(ref, result, 1e-3, 1e-3));231}
232
233} // namespace jit234} // namespace torch235