// Tests for NNC (TensorExpr) ExternalCall lowering: each test builds an
// ExternalCall-based kernel, runs it through LLVM codegen (when available)
// and the SimpleIREvaluator, and compares against the ATen reference.
1#include <gtest/gtest.h>2
3#include <test/cpp/tensorexpr/test_base.h>4
5#include <torch/csrc/jit/ir/irparser.h>6#include <torch/csrc/jit/passes/subgraph_rewrite.h>7#include <torch/csrc/jit/passes/tensorexpr_fuser.h>8#include <torch/csrc/jit/runtime/custom_operator.h>9#include <torch/csrc/jit/tensorexpr/kernel.h>10
11#include <test/cpp/tensorexpr/test_utils.h>12#include <torch/csrc/jit/runtime/operator.h>13#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>14#include <torch/csrc/jit/tensorexpr/eval.h>15#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>16#include <torch/csrc/jit/tensorexpr/ir.h>17#include <torch/csrc/jit/tensorexpr/ir_printer.h>18#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>19#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>20#include <torch/csrc/jit/tensorexpr/loopnest.h>21#include <torch/csrc/jit/tensorexpr/tensor.h>22
23#include <torch/csrc/jit/testing/file_check.h>24#include <torch/jit.h>25
26#include <ATen/NativeFunctions.h>27#include <ATen/core/dispatch/Dispatcher.h>28#include <ATen/native/xnnpack/OpContext.h>29
30namespace torch {31namespace jit {32using namespace torch::jit::tensorexpr;33
// Checks that the "nnc_aten_conv1d" external function matches at::conv1d for
// float tensors. Depthwise setup: groups == in_channels == 100; with kernel 7,
// pad 3, stride 1 the output length stays 115.
TEST(ExternalCall, Conv1d_float) {
  BufHandle Input("Input", {1, 100, 115}, kFloat);
  BufHandle Weight("Weight", {100, 1, 7}, kFloat);
  BufHandle Bias("Bias", {100}, kFloat);
  BufHandle ResultBuf("Result", {1, 100, 115}, kFloat);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  int64_t groups = 100;

  // The conv parameters travel as the ExternalCall's scalar "extra args".
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  // Reference result computed directly with ATen on constant-filled tensors.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 100, 115}, options) * 5.f;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f;
  at::Tensor bias = at::ones({100}, options) * 11.f;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  // Flat buffers with the same constant fill; result pre-filled with -1 so a
  // no-op kernel would be detected.
  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 100 * 115, 5.f);
  std::vector<float> weight_buf(100 * 1 * 7, 6.f);
  std::vector<float> bias_buf(100, 11.f);
  std::vector<float> result_buf(1 * 100 * 115, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
86
TEST(ExternalCall, Conv1d_int) {
  // A similar test, but now using kInt tensors
  BufHandle Input("Input", {1, 100, 115}, kInt);
  BufHandle Weight("Weight", {100, 1, 7}, kInt);
  BufHandle Bias("Bias", {100}, kInt);
  BufHandle ResultBuf("Result", {1, 100, 115}, kInt);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  int64_t groups = 100;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  // Reference computed with ATen on int32 tensors.
  auto options = at::TensorOptions()
                     .dtype(at::kInt)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 100, 115}, options) * 5;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6;
  at::Tensor bias = at::ones({100}, options) * 11;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  // int32_t buffers to match the kInt dtype of the NNC buffers above.
  at::Tensor nnc_result;
  std::vector<int32_t> input_buf(1 * 100 * 115, 5);
  std::vector<int32_t> weight_buf(100 * 1 * 7, 6);
  std::vector<int32_t> bias_buf(100, 11);
  std::vector<int32_t> result_buf(1 * 100 * 115, -1);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
140
141TEST(ExternalCall, Conv1d_nobias_noargs) {142BufHandle Input("Input", {1, 1, 115}, kFloat);143BufHandle Weight("Weight", {10, 1, 7}, kFloat);144BufHandle ResultBuf("Result", {1, 10, 109}, kFloat);145
146Tensor Result = Tensor(147ResultBuf.node(),148ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {}));149LoopNest l({Result});150l.prepareForCodegen();151l.simplify();152
153auto options = at::TensorOptions()154.dtype(at::kFloat)155.layout(at::kStrided)156.device(at::kCPU)157.requires_grad(false);158at::Tensor input = at::ones({1, 1, 115}, options) * 5.f;159at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f;160at::Tensor ref = at::conv1d(input, weight);161
162at::Tensor nnc_result;163std::vector<float> input_buf(1 * 1 * 115, 5.f);164std::vector<float> weight_buf(10 * 1 * 7, 6.f);165std::vector<float> result_buf(1 * 10 * 109, -1.f);166
167#ifdef TORCH_ENABLE_LLVM168LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});169
170llvm_codegen.call({input_buf, weight_buf, result_buf});171nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);172ASSERT_TRUE(at::allclose(nnc_result, ref));173#endif174
175SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});176
177ir_eval.call({input_buf, weight_buf, result_buf});178nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);179ASSERT_TRUE(at::allclose(nnc_result, ref));180}
181
// Checks that "nnc_aten_conv2d" matches at::conv2d for float tensors.
// Stride 2 with pad 1 and a 3x3 kernel halves the 224x224 spatial dims to
// 112x112.
TEST(ExternalCall, Conv2d_float) {
  BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
  BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat);
  BufHandle Bias("Bias", {16}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  // 2d conv takes per-axis stride/pad/dilation, hence the doubled entries.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv2d",
          {Input, Weight, Bias},
          {stride, stride, pad, pad, dilation, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f;
  at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f;
  at::Tensor bias = at::ones({16}, options) * 11.f;
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 3 * 224 * 224, 5.f);
  std::vector<float> weight_buf(16 * 3 * 3 * 3, 6.f);
  std::vector<float> bias_buf(16, 11.f);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
240
TEST(ExternalCall, Conv2d_int) {
  // A similar test, but now using kInt tensors

  BufHandle Input("Input", {1, 3, 224, 224}, kInt);
  BufHandle Weight("Weight", {16, 3, 3, 3}, kInt);
  BufHandle Bias("Bias", {16}, kInt);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv2d",
          {Input, Weight, Bias},
          {stride, stride, pad, pad, dilation, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kInt)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5;
  at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6;
  at::Tensor bias = at::ones({16}, options) * 11;
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  // int32_t buffers to match the kInt dtype of the NNC buffers above.
  at::Tensor nnc_result;
  std::vector<int32_t> input_buf(1 * 3 * 224 * 224, 5);
  std::vector<int32_t> weight_buf(16 * 3 * 3 * 3, 6);
  std::vector<int32_t> bias_buf(16, 11);
  std::vector<int32_t> result_buf(1 * 16 * 112 * 112, -1);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
301
302TEST(ExternalCall, Conv2d_nobias_noargs) {303BufHandle Input("Input", {1, 16, 112, 112}, kFloat);304BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat);305BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);306
307Tensor Result = Tensor(308ResultBuf.node(),309ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {}));310LoopNest l({Result});311l.prepareForCodegen();312l.simplify();313
314auto options = at::TensorOptions()315.dtype(at::kFloat)316.layout(at::kStrided)317.device(at::kCPU)318.requires_grad(false);319at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f;320at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;321at::Tensor ref = at::conv2d(input, weight);322
323at::Tensor nnc_result;324std::vector<float> input_buf(1 * 16 * 112 * 112, 5.f);325std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);326std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);327
328#ifdef TORCH_ENABLE_LLVM329LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});330
331llvm_codegen.call({input_buf, weight_buf, result_buf});332nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);333ASSERT_TRUE(at::allclose(nnc_result, ref));334#endif335
336SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});337
338ir_eval.call({input_buf, weight_buf, result_buf});339nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);340ASSERT_TRUE(at::allclose(nnc_result, ref));341}
342
343TEST(ExternalCall, Addmm_float) {344BufHandle Input("Input", {100, 300}, kFloat);345BufHandle Mat1("Mat1", {100, 200}, kFloat);346BufHandle Mat2("Mat2", {200, 300}, kFloat);347BufHandle ResultBuf("Result", {100, 300}, kFloat);348int64_t beta = 2;349int64_t alpha = 2;350
351Tensor Result = Tensor(352ResultBuf.node(),353ExternalCall::make(354ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha}));355LoopNest l({Result});356l.prepareForCodegen();357l.simplify();358
359auto options = at::TensorOptions()360.dtype(at::kFloat)361.layout(at::kStrided)362.device(at::kCPU)363.requires_grad(false);364at::Tensor input = at::ones({100, 300}, options) * 5.f;365at::Tensor mat1 = at::ones({100, 200}, options) * 6.f;366at::Tensor mat2 = at::ones({200, 300}, options) * 11.f;367at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha);368
369at::Tensor nnc_result;370std::vector<float> input_buf(100 * 300, 5.f);371std::vector<float> mat1_buf(100 * 200, 6.f);372std::vector<float> mat2_buf(200 * 300, 11.f);373std::vector<float> result_buf(100 * 300, -1.f);374
375#ifdef TORCH_ENABLE_LLVM376LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result});377
378llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf});379nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);380ASSERT_TRUE(at::allclose(nnc_result, ref));381#endif382
383SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result});384
385ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf});386nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);387ASSERT_TRUE(at::allclose(nnc_result, ref));388}
389
// Checks "nnc_aten_embedding" against at::embedding. Note the mixed dtypes:
// float weights, int64 indices — so `options` carries no dtype and each
// tensor sets its own.
TEST(ExternalCall, Embedding) {
  BufHandle Weight("Weight", {256, 100}, kFloat);
  BufHandle Indices("Indices", {1, 115}, kLong);
  BufHandle ResultBuf("Result", {1, 115, 100}, kFloat);
  int64_t padding_idx = -1;
  bool scale_grad_by_freq = false;
  bool sparse = false;

  // Bool flags are passed as int64 scalar extra args.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_embedding",
          {Weight, Indices},
          {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f;
  at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6;
  at::Tensor ref =
      at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);

  at::Tensor nnc_result;
  std::vector<float> weight_buf(256 * 100, 5.f);
  std::vector<int64_t> indices_buf(1 * 115, 6);
  std::vector<float> result_buf(1 * 115 * 100, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result});

  llvm_codegen.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result});

  ir_eval.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
440
441TEST(ExternalCall, MaxReduction) {442BufHandle Input("Input", {1, 115, 152}, kFloat);443BufHandle ResultBuf("Result", {1, 152}, kFloat);444int64_t dim = 1;445bool keep_dim = false;446
447Tensor Result = Tensor(448ResultBuf.node(),449ExternalCall::make(450ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim}));451LoopNest l({Result});452l.prepareForCodegen();453l.simplify();454
455auto options = at::TensorOptions()456.dtype(at::kFloat)457.layout(at::kStrided)458.device(at::kCPU)459.requires_grad(false);460
461at::Tensor input = at::ones({1, 115, 152}, options) * 5.f;462at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim));463
464at::Tensor nnc_result;465std::vector<float> input_buf(1 * 115 * 152, 5.f);466std::vector<float> result_buf(1 * 152, -1.f);467
468#ifdef TORCH_ENABLE_LLVM469LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result});470
471llvm_codegen.call({input_buf, result_buf});472nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);473ASSERT_TRUE(at::allclose(nnc_result, ref));474#endif475
476SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result});477
478ir_eval.call({input_buf, result_buf});479nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);480ASSERT_TRUE(at::allclose(nnc_result, ref));481}
482
483#ifdef USE_XNNPACK484
// Checks "nnc_prepacked_linear_clamp_run" against at::linear, using an
// XNNPACK prepacked linear context fetched through the dispatcher.
TEST(ExternalCall, Prepacked_Linear_float) {
  using namespace at::native::xnnpack;

  BufHandle Input("Input", {100, 200}, kFloat);
  BufHandle ResultBuf("Result", {100, 300}, kFloat);

  // Calculate reference result using at::linear.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input =
      at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200});
  at::Tensor bias = at::linspace(-10.0, 10.0, 300, options);
  at::Tensor ref = at::linear(input, weight, bias);

  // Create prepacked xnnpack context object.
  auto linear_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::linear_clamp_prepack", "")
          .typed<c10::intrusive_ptr<LinearOpContext>(
              at::Tensor,
              std::optional<at::Tensor>,
              const std::optional<at::Scalar>&,
              const std::optional<at::Scalar>&)>();
  auto prepacked = linear_clamp_prepack_op.call(
      weight, bias, std::optional<at::Scalar>(), std::optional<at::Scalar>());

  // DummyPrepacked is a placeholder buffer; the real context pointer is
  // supplied at call time as prepacked.get().
  BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_linear_clamp_run",
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 100 * 200);
  std::vector<float> result_buf(100 * 300, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
547
// Checks "nnc_prepacked_conv2d_clamp_run" against at::conv2d, using an
// XNNPACK prepacked conv2d context. Looser tolerances (1e-3) account for
// XNNPACK's different accumulation order.
TEST(ExternalCall, Prepacked_Conv2d_float) {
  using namespace at::native::xnnpack;

  BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  // Calculate reference result using at::conv2d.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options)
                         .resize_({1, 3, 224, 224});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3});
  at::Tensor bias = at::linspace(-10.0, 10.0, 16, options);
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  // Create prepacked xnnpack context object.
  auto conv2d_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "")
          .typed<c10::intrusive_ptr<Conv2dOpContext>(
              at::Tensor,
              std::optional<at::Tensor>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              int64_t,
              const std::optional<at::Scalar>&,
              const std::optional<at::Scalar>&)>();
  auto prepacked = conv2d_clamp_prepack_op.call(
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups,
      std::optional<at::Scalar>(),
      std::optional<at::Scalar>());

  // DummyPrepacked is a placeholder buffer; the real context pointer is
  // supplied at call time as prepacked.get().
  BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_conv2d_clamp_run",
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 1 * 3 * 224 * 224);
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
}
632
633#endif // USE_XNNPACK634
// Table-driven check of binary external calls (matmul, mv, mm) against their
// ATen counterparts. Each entry: a-shape, b-shape, result-shape, reference
// function, external-call name.
TEST(ExternalCall, BinaryFloat) {
  using TensorFunc = std::function<at::Tensor(at::Tensor, at::Tensor)>;
  using Test = std::tuple<
      std::vector<int64_t>,
      std::vector<int64_t>,
      std::vector<int64_t>,
      TensorFunc,
      std::string>;
  std::vector<Test> tests = {};
  tests.push_back(
      Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"});
  tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"});
  tests.push_back(
      Test{{100, 200}, {200, 300}, {100, 300}, at::mm, "nnc_aten_mm"});
  for (auto curTest : tests) {
    auto [aShape, bShape, resShape, torchFunc, externCallName] = curTest;
    // Shapes are stored as int64 but BufHandle wants ExprHandles; go through
    // int to build immediates.
    auto toExprHandleVec = [](std::vector<int64_t> v) {
      auto intV = std::vector<int>(v.begin(), v.end());
      return std::vector<ExprHandle>(intV.begin(), intV.end());
    };
    BufHandle A("A", toExprHandleVec(aShape), kFloat);
    BufHandle B("B", toExprHandleVec(bShape), kFloat);
    BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);

    Tensor Result = Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, externCallName, {A, B}, {}));
    LoopNest l({Result});
    l.prepareForCodegen();
    l.simplify();

    auto options = at::TensorOptions()
                       .dtype(at::kFloat)
                       .layout(at::kStrided)
                       .device(at::kCPU)
                       .requires_grad(false);
    at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
    at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f;
    at::Tensor ref = torchFunc(a, b);

    // Number of elements in a shape, for sizing the flat buffers.
    auto prod = [](std::vector<int64_t> v) {
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
    };

    at::Tensor nnc_result;
    std::vector<float> a_buf(prod(aShape), 5.f);
    std::vector<float> b_buf(prod(bShape), 6.f);
    std::vector<float> result_buf(prod(resShape), -1.f);

#ifdef TORCH_ENABLE_LLVM
    LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result});

    llvm_codegen.call({a_buf, b_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

    SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result});
    ir_eval.call({a_buf, b_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
  }
}
701
// Table-driven check of unary external calls (adaptive_avg_pool2d, mean)
// against ATen. Each entry: input shape, result shape, reference function,
// external-call name, scalar extra args for the ExternalCall.
TEST(ExternalCall, UnaryFloat) {
  using TensorFunc = std::function<at::Tensor(at::Tensor)>;
  auto toExprHandleVec = [](std::vector<int64_t> v) {
    auto intV = std::vector<int>(v.begin(), v.end());
    return std::vector<ExprHandle>(intV.begin(), intV.end());
  };
  using Test = std::tuple<
      std::vector<int64_t>,
      std::vector<int64_t>,
      TensorFunc,
      std::string,
      std::vector<ExprHandle>>;
  std::vector<Test> tests = {};
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {1, 64, 8, 9},
                       {1, 64, 5, 7},
                       [](at::Tensor x) {
                         return at::adaptive_avg_pool2d(x, {5, 7});
                       },
                       "nnc_aten_adaptive_avg_pool2d",
                       toExprHandleVec({5, 7})});
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {100, 200},
                       {100},
                       [](at::Tensor x) { return at::mean(x, {1}); },
                       "nnc_aten_mean",
                       toExprHandleVec({1, /*keepdim=*/0})});
  for (auto curTest : tests) {
    auto [aShape, resShape, torchFunc, externCallName, externCallArgs] =
        curTest;
    BufHandle A("A", toExprHandleVec(aShape), kFloat);
    BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);

    Tensor Result = Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs));
    LoopNest l({Result});
    l.prepareForCodegen();
    l.simplify();

    auto options = at::TensorOptions()
                       .dtype(at::kFloat)
                       .layout(at::kStrided)
                       .device(at::kCPU)
                       .requires_grad(false);
    at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
    at::Tensor ref = torchFunc(a);

    // Number of elements in a shape, for sizing the flat buffers.
    auto prod = [](std::vector<int64_t> v) {
      // NOLINTNEXTLINE(modernize-use-transparent-functors)
      return std::accumulate(v.begin(), v.end(), 1, std::multiplies<int64_t>());
    };

    at::Tensor nnc_result;
    std::vector<float> a_buf(prod(aShape), 5.f);
    std::vector<float> result_buf(prod(resShape), -1.f);

#ifdef TORCH_ENABLE_LLVM
    LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result});

    llvm_codegen.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

    SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result});
    ir_eval.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
  }
}
775
TEST(ExternalCall, ComputeInterop) {
  // This test verifies that Tensors using external calls can be used by and can
  // use Tensors built with Compute API.

  BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat);
  BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat);

  // Compute-API producers feeding the first external call.
  Tensor Input = Compute(
      "Input",
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(5.0f); });
  Tensor Weight = Compute(
      "Weight",
      {16, 16, 1, 1},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(6.0f); });

  // Chain: Compute -> ExternalCall(conv2d) -> ExternalCall(matmul) ->
  // Compute (elementwise sum).
  Tensor ConvResult = Tensor(
      ConvResultBuf.node(),
      ExternalCall::make(
          ConvResultBuf,
          "nnc_aten_conv2d",
          {BufHandle(Input.buf()), BufHandle(Weight.buf())},
          {}));
  Tensor MatmulResult = Tensor(
      MatmulResultBuf.node(),
      ExternalCall::make(
          MatmulResultBuf,
          "nnc_aten_matmul",
          {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())},
          {}));
  Tensor Result = Compute(
      "Result",
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) {
        return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w);
      });

  LoopNest l({Input, Weight, ConvResult, MatmulResult, Result});

  // Inlining should not inline anything here since all Bufs are either defined
  // or used in ExternalCalls - we run it just for testing
  l.inlineIntermediateBufs(true);

  l.prepareForCodegen();
  l.simplify();

  // ATen reference for the same chain.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f;
  at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
  at::Tensor t = at::conv2d(input, weight);
  at::Tensor t2 = at::matmul(t, t);
  at::Tensor ref = t + t2;

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 16 * 32 * 32, 5.f);
  std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
  std::vector<float> conv_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> matmul_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> result_buf(1 * 16 * 32 * 32, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  llvm_codegen.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  ir_eval.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
867
TEST(ExternalCall, Inlining) {
  // This test verifies that Tensors using external calls can be used by and
  // can use Tensors built with Compute API.

  BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);

  Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
    return FloatImm::make(5.0f);
  });
  Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
    return FloatImm::make(4.0f);
  });
  Tensor MatmulResult = Tensor(
      MatmulResultBuf.node(),
      ExternalCall::make(
          MatmulResultBuf,
          "nnc_aten_matmul",
          {BufHandle(A.buf()), BufHandle(B.buf())},
          {}));
  Tensor Result =
      Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
        return MatmulResult.load(i, j) + FloatImm::make(3.0f);
      });

  // Build the root block manually so the LoopNest sees only Result as output.
  StmtPtr root_stmt = alloc<torch::jit::tensorexpr::Block>(std::vector<StmtPtr>(
      {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()}));
  LoopNest l(root_stmt, {Result.buf()});

  // Inlining should not inline anything here since all Bufs are either
  // defined or used in ExternalCalls
  l.inlineIntermediateBufs(false);

  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor a = at::ones({8, 8}, options) * 5.f;
  at::Tensor b = at::ones({8, 8}, options) * 4.f;
  at::Tensor t = at::matmul(a, b);
  at::Tensor ref = t + 3.f;

  // Only the output buffer is bound; intermediates are allocated by codegen.
  at::Tensor nnc_result;
  std::vector<float> result_buf(8 * 8);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Result});

  llvm_codegen.call({result_buf});
  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Result});

  ir_eval.call({result_buf});
  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
930
931TEST(ExternalCall, JitCustomFusionOp) {932const char* custom_op_schema_literal =933"nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor";934const char* external_func_name = "nnc_add_mul";935
936auto add_mul_lowering_func =937[external_func_name](938const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,939const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,940const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,941const std::optional<torch::jit::tensorexpr::ScalarType>& output_type,942at::Device device) {943auto output_dtype = Dtype(*output_type);944torch::jit::tensorexpr::BufHandle result_buf(945"nnc_add_mul_res_buf", output_shape, output_dtype);946const torch::jit::tensorexpr::BufHandle& a =947std::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);948const torch::jit::tensorexpr::BufHandle& b =949std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);950const torch::jit::tensorexpr::BufHandle& c =951std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);952torch::jit::tensorexpr::StmtPtr s =953torch::jit::tensorexpr::ExternalCall::make(954result_buf, external_func_name, {a, b, c}, {});955return Tensor(result_buf.node(), s);956};957
958auto add_mul_external_func = [](int64_t bufs_num,959void** buf_data,960int64_t* buf_ranks,961int64_t* buf_dims,962int64_t* buf_strides,963int8_t* buf_dtypes,964int64_t args_num,965int64_t* extra_args) {};966
967torch::jit::RegisterOperators reg({Operator(968custom_op_schema_literal,969[](const Node* node) -> Operation {970return [](Stack& _stack) {971auto a = std::move(peek(_stack, 0, 3)).toTensor();972auto b = std::move(peek(_stack, 1, 3)).toTensor();973auto c = std::move(peek(_stack, 2, 3)).toTensor();974drop(_stack, 3);975auto result = (a + b) * c;976pack(_stack, std::move(result));977return 0;978};979},980c10::AliasAnalysisKind::FROM_SCHEMA)});981
982auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet();983custom_operator_set.insert({custom_op_schema_literal});984
985auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry();986te_lowering_registry.insert(987parseSchema(custom_op_schema_literal), add_mul_lowering_func);988
989auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry();990te_nnc_func_registry[external_func_name] = add_mul_external_func;991
992std::string graph_string = R"IR(993graph(%a : Float(10, 20, strides=[20, 1], device=cpu),
994%b : Float(10, 20, strides=[20, 1], device=cpu),
995%c : Float(10, 20, strides=[20, 1], device=cpu)):
996%res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c)
997return (%res))IR";998
999auto graph = std::make_shared<Graph>();1000torch::jit::parseIR(graph_string, graph.get());1001
1002std::string shape_compute_python_string = R"PY(1003def computOutput(a: List[int], b: List[int], c: List[int]):
1004expandedSizes: List[int] = []
1005dimsA = len(a)
1006dimsB = len(b)
1007dimsC = len(c)
1008ndim = max(dimsA, dimsB, dimsC)
1009for i in range(ndim):
1010offset = ndim - 1 - i
1011dimA = dimsA - 1 - offset
1012dimB = dimsB - 1 - offset
1013dimC = dimsC - 1 - offset
1014sizeA = a[dimA] if (dimA >= 0) else 1
1015sizeB = b[dimB] if (dimB >= 0) else 1
1016sizeC = a[dimC] if (dimC >= 0) else 1
1017
1018if sizeA != sizeB and sizeB != sizeC and sizeA != 1 and sizeB != 1 and sizeC != 1:
1019# TODO: only assertion error is bound in C++ compilation right now
1020raise AssertionError(
1021"The size of tensor a {} must match the size of tensor b ("
1022"{} and c {}) at non-singleton dimension {}".format(sizeA, sizeB, sizeC, i)
1023)
1024
1025expandedSizes.append(max(sizeA, sizeB, sizeC))
1026
1027return expandedSizes
1028)PY";1029auto cu_ptr = torch::jit::compile(shape_compute_python_string);1030torch::jit::GraphFunction* gf =1031(torch::jit::GraphFunction*)&cu_ptr->get_function("computOutput");1032ASSERT_TRUE(gf);1033
1034#ifdef TORCH_ENABLE_LLVM1035auto static_graph_case = graph->copy();1036FuseTensorExprs(static_graph_case, 1);1037torch::jit::testing::FileCheck()1038.check("prim::TensorExprGroup_")1039->check("nnc_custom::add_mul")1040->run(*static_graph_case);1041
1042auto dynamic_graph_case = graph->copy();1043auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal);1044ASSERT_TRUE(custom_op);1045torch::jit::RegisterShapeComputeGraphForSchema(1046custom_op->schema(), gf->graph());1047FuseTensorExprs(dynamic_graph_case, 1, false, true);1048torch::jit::testing::FileCheck()1049.check("prim::TensorExprGroup_")1050->check("nnc_custom::add_mul")1051->run(*dynamic_graph_case);1052#else1053torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph);1054#endif1055}
1056
1057} // namespace jit1058} // namespace torch1059