// PyTorch — test/cpp/tensorexpr: TensorExpr fuser pass tests
// (original listing: 402 lines, 13.2 KB)
1#include <gtest/gtest.h>2
3#include <test/cpp/tensorexpr/test_base.h>4#include <torch/csrc/jit/codegen/fuser/interface.h>5#include <torch/csrc/jit/ir/ir.h>6#include <torch/csrc/jit/ir/irparser.h>7#include <torch/csrc/jit/passes/tensorexpr_fuser.h>8#include <torch/csrc/jit/runtime/interpreter.h>9#include <torch/csrc/jit/testing/file_check.h>10#include <sstream>11
12namespace torch {13namespace jit {14
15using namespace torch::jit::tensorexpr;16
// RAII guard that sets the global "can fuse on CPU" flag for the duration of
// a test and restores the previous value on scope exit. Construct with
// `false` to force the CPU fuser off for a test.
struct WithCPUFuser {
  // Capture the current canFuseOnCPU() state, then override it with `val`.
  WithCPUFuser(bool val = true) : cpuFuserEnabled(canFuseOnCPU()) {
    overrideCanFuseOnCPU(val);
  }

  // Restore the flag captured at construction time.
  ~WithCPUFuser() {
    overrideCanFuseOnCPU(cpuFuserEnabled);
  }

  // CPU-fuser state that was in effect before this guard was created.
  bool cpuFuserEnabled;
};
29TEST(TEFuserPass, FuserPass_1) {30WithCPUFuser cf;31const auto graph_string = R"IR(32graph(%0 : Float(128, strides=[1], device=cpu),
33%1 : Float(128, strides=[1], device=cpu)):
34%12 : int = prim::Constant[value=1]()
35%2.1 : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
36%2 : Float(128, strides=[1], device=cpu) = aten::mul(%2.1, %1)
37%3 : Float(128, strides=[1], device=cpu) = aten::add_(%2, %1, %12)
38%4 : Float(128, strides=[1], device=cpu) = aten::mul(%2, %1)
39%5 : Float(128, strides=[1], device=cpu) = aten::add(%2, %4, %12)
40return (%5))IR";41auto g = std::make_shared<Graph>();42torch::jit::parseIR(graph_string, g.get());43
44g->lint();45FuseTensorExprs(g);46
47// We should not be able to fuse across the in-place operation here.48testing::FileCheck()49.check("prim::TensorExprGroup_")50->check("aten::add_")51->check("prim::TensorExprGroup_")52->run(*g);53}
54
55TEST(TEFuserPass, FuserPass_2) {56WithCPUFuser cf;57const auto graph_string = R"IR(58graph(%0 : Float(128, strides=[1], device=cpu),
59%1 : Float(128, strides=[1], device=cpu)):
60%12 : int = prim::Constant[value=1]()
61%a : Float(128, strides=[1], device=cpu) = aten::mul(%0, %1)
62%b : Float(128, strides=[1], device=cpu) = aten::add(%0, %1, %12)
63%c : Float(128, strides=[1], device=cpu) = aten::add_(%b, %1, %12)
64%d : Float(128, strides=[1], device=cpu) = aten::mul(%c, %a)
65return (%d))IR";66auto g = std::make_shared<Graph>();67torch::jit::parseIR(graph_string, g.get());68
69g->lint();70FuseTensorExprs(g);71
72// We should not be able to fuse across the in-place operation here.73testing::FileCheck()74.check("aten::add_")75->check("prim::TensorExprGroup_0")76->run(*g);77}
78
79TEST(TEFuserPass, FuserPass_3) {80WithCPUFuser cf;81const auto graph_string = R"IR(82graph(%x : Float(128, strides=[1], device=cpu),
83%y : Float(128, strides=[1], device=cpu)):
84%r : Float(128, strides=[1], device=cpu) = aten::mul(%x, %y)
85return (%r))IR";86{87auto g = std::make_shared<Graph>();88torch::jit::parseIR(graph_string, g.get());89
90g->lint();91FuseTensorExprs(g, /* min_group_size= */ 2);92
93// We should not create a fusion group since its size would be too small94testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);95}96{97auto g = std::make_shared<Graph>();98torch::jit::parseIR(graph_string, g.get());99
100g->lint();101FuseTensorExprs(g, /* min_group_size= */ 1);102
103// We should create a fusion group since its size is above the threshold104testing::FileCheck().check("prim::TensorExprGroup")->run(*g);105}106}
107
108TEST(TEFuserPass, FuserPass_0DimInput) {109WithCPUFuser cf;110const auto graph_string = R"IR(111graph(%x : Float(device=cpu),
112%y : Float(device=cpu)):
113%one : int = prim::Constant[value=1]()
114%a : Float(device=cpu) = aten::mul(%x, %y)
115%b : Float(device=cpu) = aten::add(%x, %a, %one)
116return (%b))IR";117auto g = std::make_shared<Graph>();118torch::jit::parseIR(graph_string, g.get());119
120g->lint();121FuseTensorExprs(g);122
123// We should fuse 0-dim tensors too124testing::FileCheck().check("prim::TensorExprGroup")->run(*g);125}
126
127TEST(TEFuserPass, FuserPass_UnfusibleDevice) {128WithCPUFuser cf(false);129const auto graph_string = R"IR(130graph(%x : Float(10, strides=[1], device=cpu),
131%y : Float(10, strides=[1], device=cpu)):
132%a : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
133return (%a))IR";134auto g = std::make_shared<Graph>();135torch::jit::parseIR(graph_string, g.get());136
137g->lint();138FuseTensorExprs(g, /* min_group_size= */ 1);139
140// Test that we're not starting fusion groups from nodes with unfusible device141testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);142}
143
144TEST(TEFuserPass, FuserPass_UnknownShapes) {145WithCPUFuser cf;146const auto graph_string = R"IR(147graph(%x : Tensor,
148%y : Tensor):
149%a : Tensor = aten::mul(%x, %y)
150%b : Tensor = aten::mul(%x, %a)
151return (%b))IR";152auto g = std::make_shared<Graph>();153torch::jit::parseIR(graph_string, g.get());154
155g->lint();156FuseTensorExprs(g);157
158// Test that we're not generating fusion groups when shapes are not known159testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);160}
161
162TEST(TEFuserPass, FuserPass_Multidevice) {163{164WithCPUFuser cf;165const auto graph_string = R"IR(166graph(%x : Float(10, strides=[1], device=cpu),
167%y : Float(20, strides=[1], device=cpu),
168%z : Float(30, strides=[1], device=cpu)):
169%dim : int = prim::Constant[value=0]()
170%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
171%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
172return (%cat))IR";173auto g = std::make_shared<Graph>();174torch::jit::parseIR(graph_string, g.get());175
176g->lint();177FuseTensorExprs(g, /* min_group_size= */ 1);178
179// We should be able to fuse this180testing::FileCheck().check("prim::TensorExprGroup")->run(*g);181}182{183WithCPUFuser cf;184const auto graph_string = R"IR(185graph(%x : Float(10, strides=[1], device=cpu),
186%y : Float(20, strides=[1], device=cuda:0),
187%z : Float(30, strides=[1], device=cpu)):
188%dim : int = prim::Constant[value=0]()
189%xyz_list : Tensor[] = prim::ListConstruct(%x, %y, %z)
190%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xyz_list, %dim)
191return (%cat))IR";192auto g = std::make_shared<Graph>();193torch::jit::parseIR(graph_string, g.get());194
195g->lint();196FuseTensorExprs(g, /* min_group_size= */ 1);197
198// We should not fuse this aten::cat since its inputs are from different199// devices200testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);201}202{203WithCPUFuser cf;204const auto graph_string = R"IR(205graph(%x : Float(10, strides=[1], device=cpu),
206%y : Float(20, strides=[1], device=cpu),
207%z : Float(10, strides=[1], device=cuda:0)):
208%dim : int = prim::Constant[value=0]()
209%xy_list : Tensor[] = prim::ListConstruct(%x, %y)
210%xy_cat : Float(30, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
211%r : Float(30, strides=[1], device=cpu) = aten::mul(%xy_cat, %z)
212return (%r))IR";213auto g = std::make_shared<Graph>();214torch::jit::parseIR(graph_string, g.get());215
216g->lint();217FuseTensorExprs(g, /* min_group_size= */ 2);218
219// Test that we check device before merging one node (cat) into another220// (mul)221testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);222}223{224WithCPUFuser cf;225const auto graph_string = R"IR(226graph(%x : Float(10, strides=[1], device=cpu),
227%y : Float(20, strides=[1], device=cpu),
228%z : Float(10, strides=[1], device=cuda:0)):
229%z2 : Tensor = aten::mul(%z, %z)
230%dim : int = prim::Constant[value=0]()
231%xy_list : Tensor[] = prim::ListConstruct(%x, %y, %z2)
232%cat : Float(60, strides=[1], device=cpu) = aten::cat(%xy_list, %dim)
233return (%cat))IR";234auto g = std::make_shared<Graph>();235torch::jit::parseIR(graph_string, g.get());236
237g->lint();238FuseTensorExprs(g, /* min_group_size= */ 2);239
240// Test that we check device before merging one node (mul) into another241// (cat)242testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);243}244{245WithCPUFuser cf;246const auto graph_string = R"IR(247graph(%x : Float(10, strides=[1], device=cpu),
248%y : Float(20, strides=[1], device=cuda:0)):
249%r : Float(10, strides=[1], device=cpu) = aten::mul(%x, %y)
250return (%r))IR";251auto g = std::make_shared<Graph>();252torch::jit::parseIR(graph_string, g.get());253
254g->lint();255FuseTensorExprs(g, /* min_group_size= */ 1);256
257// We should not fuse this graph since its inputs are from different devices258testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);259}260{261WithCPUFuser cf;262const auto graph_string = R"IR(263graph(%x : Float(10, strides=[1], device=cuda:0),
264%y : Float(20, strides=[1], device=cuda:1),
265%z : Float(20, strides=[1], device=cpu)):
266%x2 : Float(10, strides=[1], device=cpu) = aten::mul(%x, %x)
267%y2 : Float(10, strides=[1], device=cpu) = aten::mul(%y, %y)
268%z2 : Float(10, strides=[1], device=cpu) = aten::mul(%z, %z)
269return (%x2, %y2, %z2))IR";270auto g = std::make_shared<Graph>();271torch::jit::parseIR(graph_string, g.get());272
273g->lint();274FuseTensorExprs(g, /* min_group_size= */ 2);275
276// We should not fuse these two computations since they use different277// devices278testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);279}280}
281
282TEST(TEFuserPass, FuserPass_MergeGroups) {283WithCPUFuser cf;284const auto graph_string = R"IR(285graph(%a : Float(128, strides=[1], device=cpu),
286%b : Float(128, strides=[1], device=cpu)):
287%x : Float(128, strides=[1], device=cpu) = aten::mul(%a, %a)
288%y : Float(128, strides=[1], device=cpu) = aten::mul(%b, %b)
289return (%x, %y))IR";290auto g = std::make_shared<Graph>();291torch::jit::parseIR(graph_string, g.get());292
293g->lint();294FuseTensorExprs(g, /* min_group_size= */ 1);295
296// The %x and %y computations are completely independent and yet we should put297// them into a single fusion group rather than having two separate ones.298testing::FileCheck()299.check("= prim::TensorExprGroup_")300->check_not("= prim::TensorExprGroup_")301->run(*g);302}
303
304TEST(TEFuserPass, FuserPass_IgnoreUnknownShapeAtStart) {305WithCPUFuser cf;306const auto graph_string = R"IR(307graph(%x : Bool(8, strides=[1], device=cpu),
308%y : Bool(8, strides=[1], device=cpu)):
309%a : Bool(8, strides=[1], device=cpu) = aten::__and__(%x, %y)
310%b : Tensor = aten::__or__(%a, %y)
311return (%b)
312)IR";313auto g = std::make_shared<Graph>();314torch::jit::parseIR(graph_string, g.get());315g->lint();316FuseTensorExprs(g, /* min_group_size= */ 2);317testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);318}
319
320TEST(TEFuserPass, FuserPass_Where) {321WithCPUFuser cf;322const auto graph_string = R"IR(323graph(%x : Float(8, strides=[1], device=cpu),
324%y : Float(8, strides=[1], device=cpu),
325%z : Float(8, strides=[1], device=cpu)):
326%cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
327%b : Float(8, strides=[1], device=cpu) = aten::where(%cond, %y, %z)
328return (%b)
329)IR";330auto g = std::make_shared<Graph>();331torch::jit::parseIR(graph_string, g.get());332g->lint();333FuseTensorExprs(g, /* min_group_size= */ 2);334testing::FileCheck().check("prim::TensorExprGroup")->run(*g);335}
336
337TEST(TEFuserPass, FuserPass_WhereList) {338WithCPUFuser cf;339const auto graph_string = R"IR(340graph(%x : Float(8, strides=[1], device=cpu),
341%y : Float(8, strides=[1], device=cpu),
342%z : Float(8, strides=[1], device=cpu)):
343%cond : Bool(8, strides=[1], device=cpu) = aten::eq(%x, %y)
344%b : Tensor[] = aten::where(%cond)
345return (%b)
346)IR";347auto g = std::make_shared<Graph>();348torch::jit::parseIR(graph_string, g.get());349g->lint();350FuseTensorExprs(g, /* min_group_size= */ 2);351testing::FileCheck().check_not("prim::TensorExprGroup")->run(*g);352}
353
354TEST(TEFuserPass, DynamicShapeFusion) {355WithCPUFuser cf;356const auto graph_string = R"IR(357graph(%0 : Float(10, 5, strides=[5, 1], device=cpu),
358%1 : Float(10, 5, strides=[5, 1], device=cpu)):
359%2 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%0, %1)
360%3 : Float(10, 5, strides=[5, 1], device=cpu) = aten::mul(%2, %1)
361return (%3))IR";362auto g = std::make_shared<Graph>();363torch::jit::parseIR(graph_string, g.get());364
365g->lint();366FuseTensorExprs(367g,368/* min_group_size = */ 2,369/* add_composed_op = */ true,370/* fuse_to_dynamic_shapes = */ true);371Code code(g, "");372
373testing::FileCheck()374.check("prim::TensorExprDynamicGroup_")375->check("prim::TensorExprDynamicGuard")376->check("prim::TensorExprGroup_")377->run(*g);378
379auto run_and_compare = [&](const std::vector<at::Tensor>& inputs) {380TORCH_INTERNAL_ASSERT(inputs.size() == 2);381
382auto ref = at::mul(at::mul(inputs[0], inputs[1]), inputs[1]);383
384InterpreterState interp(code);385Stack stack(inputs.begin(), inputs.end());386interp.run(stack);387at::Tensor out = pop(stack).toTensor();388ASSERT_TRUE(at::allclose(out, ref));389};390
391std::vector<at::Tensor> inputs = {at::rand({10, 5}), at::rand({10, 5})};392run_and_compare(inputs);393
394std::vector<at::Tensor> inputs2 = {at::rand({20, 5}), at::rand({20, 5})};395run_and_compare(inputs2);396
397std::vector<at::Tensor> inputs3 = {at::rand({25, 60}), at::rand({25, 60})};398run_and_compare(inputs3);399}
400
401} // namespace jit402} // namespace torch403