pytorch
3711 lines · 111.9 KB
1#include <ATen/core/dispatch/OperatorOptions.h>2#include <c10/core/ScalarType.h>3#include <gtest/gtest.h>4#include <torch/csrc/jit/ir/alias_analysis.h>5#include <torch/csrc/jit/ir/irparser.h>6#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>7#include <torch/csrc/jit/runtime/static/impl.h>8#include <torch/csrc/jit/runtime/static/passes.h>9#include <torch/csrc/jit/testing/file_check.h>10#include <stdexcept>11
12#include "deep_wide_pt.h"13#include "test_utils.h"14
15using namespace caffe2;16using namespace torch;17using namespace torch::jit;18using namespace torch::jit::test;19using c10::IValue;20
21/*
22When adding a test for an operator implemented in static runtime, there are
23several things that you need to pay attention to:
24
251) if the op is an out variant, in the test script of the op,
26instead of:
27def forward(self, input):
28return myop(input)
29
30do:
31def forward(self, input):
32return myop(input).clone()
33
34This makes sure that the output of myop is managed by the memory planner and
35exercise the code path in the op impl that otherwise doesn't get exercised. The
36output of the model is not managed by the memory planner, because it needs to
37be returned to the client.
38
392) The memory planner rounds up the size of each Tensor's storage to multiples
40of 64 bytes (alignment requirement on AVX512). Make sure the sizes of the input
41tensors in args2 are big enough to trigger resizing.
42
433) for view ops such as aten::reshape or aten::to, if you want it to be
44replaced by the copy version with the ReplaceWithCopy pass in passes.h, you
45also want to make sure its output is not returned as the model output. The
46reason is that ReplaceWithCopy only replaces the op whose output is not an
47alias of the model output.
48*/
49
50C10_DECLARE_bool(static_runtime_enable_fast_math);51
52TEST(StaticRuntime, UnaryOps) {53const auto aten_sum = R"JIT(54def forward(self, input):
55return torch.sum(input).clone()
56)JIT";57
58const auto aten_sum_0 = R"JIT(59def forward(self, input):
60return torch.sum(input, 0).clone()
61)JIT";62
63const auto aten_sum_1 = R"JIT(64def forward(self, input):
65return torch.sum(input, 1).clone()
66)JIT";67
68const auto aten_sum_0_true = R"JIT(69def forward(self, input):
70return torch.sum(input, 0, True).clone()
71)JIT";72
73const auto aten_sum_1_true = R"JIT(74def forward(self, input):
75return torch.sum(input, 1, True).clone()
76)JIT";77
78auto a = at::randn({2, 3});79auto b = at::randn({3, 3, 6});80
81std::vector<IValue> args{a}, args2{b};82
83// sum84testStaticRuntime(aten_sum, args);85testStaticRuntime(aten_sum_0, args);86testStaticRuntime(aten_sum_1, args);87testStaticRuntime(aten_sum_0_true, args);88testStaticRuntime(aten_sum_1_true, args);89
90testStaticRuntime(aten_sum, args, args2, false, false, false);91testStaticRuntime(aten_sum_0, args, args2);92testStaticRuntime(aten_sum_1, args, args2);93testStaticRuntime(aten_sum_0_true, args, args2);94testStaticRuntime(aten_sum_1_true, args, args2);95}
96
97TEST(StaticRuntime, Max) {98auto src_max_reduce = R"JIT(99def forward(self, input):
100return torch.max(input).clone()
101)JIT";102
103auto src_max_dim = R"JIT(104def forward(self, input, dim: int):
105values, indices = torch.max(input, dim)
106return values.clone(), indices.clone()
107)JIT";108
109auto src_max_dim_keepdim = R"JIT(110def forward(self, input, dim: int):
111values, indices = torch.max(input, dim, keepdim=True)
112return values.clone(), indices.clone()
113)JIT";114
115auto src_max_pointwise = R"JIT(116def forward(self, input, other):
117return torch.max(input, other).clone()
118)JIT";119
120auto input = at::randn({2, 3, 2});121auto input_other = at::randn({2, 3, 2});122auto large_input = at::randn({8, 9, 10});123auto large_input_other = at::randn({8, 9, 10});124
125testStaticRuntime(src_max_reduce, {input});126testStaticRuntime(src_max_dim, {input, 1});127testStaticRuntime(src_max_dim, {input, 1}, {large_input, 0});128testStaticRuntime(src_max_dim_keepdim, {input, 0});129testStaticRuntime(src_max_dim_keepdim, {input, 0}, {large_input, 2});130testStaticRuntime(src_max_pointwise, {input, input_other});131testStaticRuntime(src_max_pointwise, {input, input_other}, {large_input, large_input_other});132}
133
134TEST(StaticRuntime, Mean) {135const auto src_default = R"JIT(136def forward(self, input):
137return torch.mean(input).clone()
138)JIT";139const auto src_dtype = R"JIT(140def forward(self, input, dtype: int):
141return torch.mean(input, dtype=dtype).clone()
142)JIT";143const auto src_dim = R"JIT(144def forward(self, input, dim: List[int]):
145return torch.mean(input, dim).clone()
146)JIT";147const auto src_dim_keepdim = R"JIT(148def forward(self, input, dim: List[int]):
149return torch.mean(input, dim, keepdim=True).clone()
150)JIT";151const auto src_dim_dtype = R"JIT(152def forward(self, input, dim: List[int], dtype: int):
153return torch.mean(input, dim, dtype=dtype).clone()
154)JIT";155
156auto input = at::randn({2, 3, 2});157auto large_input = at::randn({8, 7, 6, 8});158
159std::vector<IValue> args_default = {input};160std::vector<IValue> args_dtype = {input, torch::kFloat};161std::vector<IValue> args_dim = {input, c10::List<int64_t>{0, 2}};162std::vector<IValue> args_dim_keepdim = {input, c10::List<int64_t>{1, 2}};163std::vector<IValue> args_dim_dtype = {input, c10::List<int64_t>{0, 1}, torch::kBFloat16};164
165testStaticRuntime(src_default, args_default);166testStaticRuntime(src_dtype, args_dtype);167testStaticRuntime(src_dim, args_dim);168testStaticRuntime(src_dim_keepdim, args_dim_keepdim);169testStaticRuntime(src_dim_dtype, args_dim_dtype);170
171std::vector<IValue> large_args_dim = {large_input, c10::List<int64_t>{0, 3}};172std::vector<IValue> large_args_dim_keepdim = {large_input, c10::List<int64_t>{1, 2}};173std::vector<IValue> large_args_dim_dtype = {large_input, c10::List<int64_t>{1, 3}, torch::kBFloat16};174
175testStaticRuntime(src_dim, args_dim, large_args_dim);176testStaticRuntime(src_dim_keepdim, args_dim_keepdim, large_args_dim_keepdim);177testStaticRuntime(src_dim_dtype, args_dim_dtype, large_args_dim_dtype);178}
179
180TEST(StaticRuntime, Sigmoid) {181const auto sigmoid_script = R"JIT(182def forward(self, inp: Tensor):
183b = torch.sigmoid(inp).clone()
184return (b)
185)JIT";186auto a = at::randn({2, 3});187auto b = at::randn({4, 3, 2});188
189std::vector<IValue> args{a}, args2{b};190
191testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);192testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);193
194FLAGS_static_runtime_enable_fast_math = false;195testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);196testStaticRuntime(sigmoid_script, args, {args2}, /*use_allclose=*/true);197FLAGS_static_runtime_enable_fast_math = true;198}
199
200TEST(StaticRuntime, Clone) {201/*202Clone called two times to trigger memory planner for output of first clone.
203The output of last op(second clone) is not managed by memory planner since it
204needs to be returned to the client and cannot be reused by planner.
205*/
206const auto clone_script_0 = R"JIT(207def forward(self, input):
208a = torch.clone(input).clone()
209return (a * a)
210)JIT";211
212// Case: clone with different set of memory_formats213const auto clone_script_1 = R"JIT(214def forward(self, input: Tensor, memory_format: int):
215a = torch.clone(input, memory_format=memory_format).clone()
216return (a * a)
217)JIT";218
219/*220Case: input stride set to 0 (due to expand op)
221calls native clone instead of out variant
222*/
223const auto clone_script_2 = R"JIT(224def forward(self, input: Tensor, other:Tensor):
225a = input.expand_as(other)
226return a.clone().clone()
227)JIT";228
229/*230Case: testing the case of sliced tensor for
231testing non-contiguous tensor storage
232*/
233const auto clone_script_3 = R"JIT(234def forward(self, input: Tensor):
235a = input[:, 0:10:2]
236return a.clone().clone()
237)JIT";238
239auto a = at::randn({2, 3});240auto b = at::randn({3, 2}).as_strided({3, 2}, {1, 3});241auto b_larger = at::randn({30, 20}).as_strided({30, 20}, {1, 3});242auto c = at::randn({1, 20, 13, 8});243auto d = at::randn({1, 0, 3, 4});244auto e = at::randn({2, 1});245auto f = at::randn({2, 10});246auto g = at::randn({3, 20});247std::vector<IValue> args_0{b, c10::MemoryFormat::Contiguous};248std::vector<IValue> args_1{b_larger, c10::MemoryFormat::Preserve};249std::vector<IValue> args_2{c, c10::MemoryFormat::ChannelsLast};250std::vector<IValue> args_3{d, c10::MemoryFormat::ChannelsLast};251std::vector<IValue> args_4{e,a};252std::vector<IValue> args_5{e,f};253
254testStaticRuntime(clone_script_0, {a});255testStaticRuntime(clone_script_0, {a}, {b_larger});256
257testStaticRuntime(clone_script_1, args_0);258testStaticRuntime(clone_script_1, args_1);259testStaticRuntime(clone_script_1, args_2);260testStaticRuntime(clone_script_1, args_3);261testStaticRuntime(clone_script_1, args_0, args_1);262testStaticRuntime(clone_script_1, args_3, args_2);263
264testStaticRuntime(clone_script_2, args_4);265testStaticRuntime(clone_script_2, args_4, args_5);266
267testStaticRuntime(clone_script_3, {f});268testStaticRuntime(clone_script_3, {f}, {g});269}
270
271TEST(StaticRuntime, Clamp) {272const auto clamp_script_1 = R"JIT(273def forward(self, inp: Tensor, min: int, max: int):
274a = torch.clamp(inp, min, max).clone()
275return (a)
276)JIT";277
278const auto clamp_script_2 = R"JIT(279def forward(self, inp: Tensor, min: Tensor, max: Tensor):
280a = torch.clamp(inp, min, max).clone()
281return (a)
282)JIT";283auto a = at::randn({2, 3});284auto max_t = at::full_like(a, 1);285auto min_t = at::full_like(a, -1);286
287auto b = at::randn({4, 3, 2});288auto max_t1 = at::full_like(b, 1);289auto min_t1 = at::full_like(b, -1);290
291testStaticRuntime(clamp_script_1, {a, -1, 1});292testStaticRuntime(clamp_script_2, {a, min_t, max_t});293
294testStaticRuntime(clamp_script_1, {a, -1, 1}, {b, -1, 1});295testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, max_t1, min_t1});296}
297
298TEST(StaticRuntime, ClampMinOnly) {299const auto src = R"JIT(300def forward(self, inp: Tensor, min: float):
301a = torch.clamp(inp, min, None).clone()
302return (a)
303)JIT";304auto a = at::randn({2, 3});305auto b = at::randn({4, 3, 2});306testStaticRuntime(src, {a, 0.5});307testStaticRuntime(src, {a, 0.5}, {b, 0.25});308}
309
310TEST(StaticRuntime, ClampMaxOnly) {311const auto src = R"JIT(312def forward(self, inp: Tensor, max: float):
313a = torch.clamp(inp, None, max).clone()
314return (a)
315)JIT";316auto a = at::randn({2, 3});317auto b = at::randn({4, 3, 2});318testStaticRuntime(src, {a, 0.5});319testStaticRuntime(src, {a, 0.5}, {b, 0.25});320}
321
322TEST(StaticRuntime, ClampIntTensor) {323const auto src = R"JIT(324def forward(self, inp: Tensor, min: float, max: float):
325a = torch.clamp(inp, min, max).clone()
326return (a)
327)JIT";328auto a = at::randint(0, 20, {2, 3}, at::kFloat);329auto b = at::randint(0, 20, {4, 3, 2}, at::kFloat);330auto min = 5.0f;331auto max = 5.0f;332testStaticRuntime(src, {a, min, max});333testStaticRuntime(src, {a, min, max}, {b, min, max});334}
335
336TEST(StaticRuntime, LenWithTuple) {337const auto src = R"IR(338graph(%input : int[]):
339%res : int = aten::len(%input)
340return (%res)
341)IR";342
343testStaticRuntime(src, {c10::List<int64_t>(4)});344}
345
346TEST(StaticRuntime, LenWithTensor) {347const auto src = R"IR(348graph(%input : Tensor):
349%res : int = aten::len(%input)
350return (%res)
351)IR";352
353testStaticRuntime(src, {at::randn({2, 2, 2})});354}
355
356TEST(StaticRuntime, LenWithStr) {357const auto src = R"IR(358graph(%input : str):
359%res : int = aten::len(%input)
360return (%res)
361)IR";362
363testStaticRuntime(src, {"static_runtime"});364}
365
366TEST(StaticRuntime, LenWithDict_str) {367const auto script = R"JIT(368def forward(self, input: Dict[str, str]):
369return len(input)
370)JIT";371
372c10::Dict<std::string, std::string> dict;373dict.insert("abc", "123");374dict.insert("def", "456");375testStaticRuntime(script, {dict});376}
377
378TEST(StaticRuntime, LenWithDict_int) {379const auto script = R"JIT(380def forward(self, input: Dict[int, int]):
381return len(input)
382)JIT";383
384c10::Dict<int64_t, int64_t> dict;385dict.insert(0, 1);386dict.insert(2, 3);387testStaticRuntime(script, {dict});388}
389
390TEST(StaticRuntime, LenWithDict_bool) {391const auto script = R"JIT(392def forward(self, input: Dict[bool, bool]):
393return len(input)
394)JIT";395
396c10::Dict<bool, bool> dict;397dict.insert(true, false);398dict.insert(false, true);399testStaticRuntime(script, {dict});400}
401
402TEST(StaticRuntime, LenWithDict_float) {403const auto script = R"JIT(404def forward(self, input: Dict[float, float]):
405return len(input)
406)JIT";407
408c10::Dict<double, double> dict;409dict.insert(0.1, 0.9);410dict.insert(0.8, 0.18);411testStaticRuntime(script, {dict});412}
413
414TEST(StaticRuntime, LenWithDict_complex) {415const auto script = R"JIT(416def forward(self, input: Dict[complex, complex]):
417return len(input)
418)JIT";419
420c10::Dict<c10::complex<double>, c10::complex<double>> dict;421dict.insert(0.1, 0.4);422dict.insert(0.9, 0.45);423testStaticRuntime(script, {dict});424}
425
426TEST(StaticRuntime, LenWithDict_Tensor) {427const auto script = R"JIT(428def forward(self, input: Dict[Tensor, Tensor]):
429return len(input)
430)JIT";431
432c10::Dict<at::Tensor, at::Tensor> dict;433dict.insert(at::randn({1, 2}), at::randn({1, 2}));434dict.insert(at::randn({1, 2}), at::randn({1, 2}));435testStaticRuntime(script, {dict});436}
437
438TEST(StaticRuntime, Logit) {439// no nnc440const auto logit_script_1 = R"JIT(441def forward(self, inp: Tensor):
442a = torch.logit(inp).clone()
443return (a)
444)JIT";445
446// with nnc447const auto logit_script_2 = R"JIT(448def forward(self, inp: Tensor):
449a = torch.logit(inp, 1e-6).clone()
450return (a)
451)JIT";452
453// no nnc454const auto logit_script_3 = R"JIT(455def forward(self, inp: Tensor, eps: float):
456a = torch.logit(inp, eps).clone()
457return (a)
458)JIT";459auto a = at::ones({2, 3});460double b = 1e-6;461std::vector<IValue> args_1{a};462std::vector<IValue> args_2({a, b});463
464auto c = at::ones({4, 3, 2});465
466// logit467testStaticRuntime(logit_script_1, args_1);468testStaticRuntime(logit_script_2, args_1);469testStaticRuntime(logit_script_3, args_2);470
471testStaticRuntime(logit_script_1, args_1, {c});472testStaticRuntime(logit_script_2, args_1, {c});473testStaticRuntime(logit_script_3, args_2, {c, b});474}
475
476TEST(StaticRuntime, EmbeddingBag) {477const std::string embedding_bag_default = R"JIT(478def forward(self, a: Tensor, b: Tensor, c: Tensor):
479x, y, z, _ = torch.embedding_bag(a, b, c)
480return (x.clone(), y.clone(), z.clone(), _.clone())
481)JIT";482
483const std::string embedding_bag_mean = R"JIT(484def forward(self, a: Tensor, b: Tensor, c: Tensor):
485x, y, z, _ = torch.embedding_bag(a, b, c, False, 1)
486return (x.clone(), y.clone(), z.clone(), _.clone())
487)JIT";488
489const std::string embedding_bag_max = R"JIT(490def forward(self, a: Tensor, b: Tensor, c: Tensor):
491x, y, z, _ = torch.embedding_bag(a, b, c, False, 2)
492return (x.clone(), y.clone(), z.clone(), _.clone())
493)JIT";494
495const std::string embedding_bag_sum_last_offset = R"JIT(496def forward(self, a: Tensor, b: Tensor, c: Tensor):
497x, y, z, _ = torch.embedding_bag(a, b, c, False, 0, False, None, True)
498return (x.clone(), y.clone(), z.clone(), _.clone())
499)JIT";500
501const std::string embedding_bag_mean_last_offset = R"JIT(502def forward(self, a: Tensor, b: Tensor, c: Tensor):
503x, y, z, _ = torch.embedding_bag(a, b, c, False, 1, False, None, True)
504return (x.clone(), y.clone(), z.clone(), _.clone())
505)JIT";506
507const std::string embedding_bag_max_last_offset = R"JIT(508def forward(self, a: Tensor, b: Tensor, c: Tensor):
509x, y, z, _ = torch.embedding_bag(a, b, c, False, 2, False, None, True)
510return (x.clone(), y.clone(), z.clone(), _.clone())
511)JIT";512
513at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);514at::Tensor input = torch::tensor({0, 1, 0, 2});515at::Tensor offset = torch::tensor({0, 2, 4});516std::vector<IValue> args{weight, input, offset};517testStaticRuntime(embedding_bag_default, args);518testStaticRuntime(embedding_bag_mean, args);519testStaticRuntime(embedding_bag_max, args);520testStaticRuntime(embedding_bag_sum_last_offset, args);521testStaticRuntime(embedding_bag_mean_last_offset, args);522testStaticRuntime(embedding_bag_max_last_offset, args);523
524at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);525at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});526at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});527std::vector<IValue> args2{weight2, input2, offset2};528testStaticRuntime(embedding_bag_default, args, args2);529testStaticRuntime(embedding_bag_mean, args, args2);530testStaticRuntime(embedding_bag_max, args, args2);531testStaticRuntime(embedding_bag_sum_last_offset, args, args2);532testStaticRuntime(embedding_bag_mean_last_offset, args, args2);533testStaticRuntime(embedding_bag_max_last_offset, args, args2);534}
535
536TEST(StaticRuntime, EmbeddingBagWithManagedOutput) {537const std::string embedding_bag_managed_output = R"JIT(538def forward(self, a: Tensor, b: Tensor, c: Tensor):
539# The outputs of embedding_bag become an intermediate tensors
540# since they are not directly returned from the graph.
541x, y, z, _ = torch.embedding_bag(a, b, c)
542return x + x
543)JIT";544
545at::Tensor weight = torch::randn({3, 8}, at::ScalarType::Float);546at::Tensor input = torch::tensor({0, 1, 0, 2});547at::Tensor offset = torch::tensor({0, 2});548std::vector<IValue> args{weight, input, offset};549
550at::Tensor weight2 = torch::randn({6, 8}, at::ScalarType::Float);551at::Tensor input2 = torch::tensor({0, 1, 0, 2, 3, 4});552at::Tensor offset2 = torch::tensor({0, 2, 4, 5});553std::vector<IValue> args2{weight2, input2, offset2};554
555testStaticRuntime(embedding_bag_managed_output, args);556testStaticRuntime(embedding_bag_managed_output, args, args2);557}
558
559TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) {560const std::string embedding_bag_default_ir = R"IR(561graph(%weight, %indices, %offsets):
562%scale_grad_by_freq : bool = prim::Constant[value=0]()
563%mode : int = prim::Constant[value=0]()
564%sparse : bool = prim::Constant[value=0]()
565%per_sample_weights : NoneType = prim::Constant()
566%include_last_offset : bool = prim::Constant[value=0]()
567%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
568%none : NoneType = prim::Constant()
569%res : Tensor = aten::clone(%y0, %none)
570return (%res)
571)IR";572auto graph = getGraphFromIR(embedding_bag_default_ir);573RemoveUnnecessaryOutputs(graph);574torch::jit::testing::FileCheck()575.check("static_runtime::embedding_bag")576->run(*graph);577
578const std::string embedding_bag_mean_ir = R"IR(579graph(%weight, %indices, %offsets):
580%scale_grad_by_freq : bool = prim::Constant[value=0]()
581%mode : int = prim::Constant[value=1]()
582%sparse : bool = prim::Constant[value=0]()
583%per_sample_weights : NoneType = prim::Constant()
584%include_last_offset : bool = prim::Constant[value=0]()
585%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
586%none : NoneType = prim::Constant()
587%res : Tensor = aten::clone(%y0, %none)
588return (%res)
589)IR";590graph = getGraphFromIR(embedding_bag_mean_ir);591RemoveUnnecessaryOutputs(graph);592torch::jit::testing::FileCheck()593.check("static_runtime::embedding_bag")594->run(*graph);595
596const std::string embedding_bag_max_last_offset_ir = R"IR(597graph(%weight, %indices, %offsets):
598%scale_grad_by_freq : bool = prim::Constant[value=0]()
599%mode : int = prim::Constant[value=2]()
600%sparse : bool = prim::Constant[value=0]()
601%per_sample_weights : NoneType = prim::Constant()
602%include_last_offset : bool = prim::Constant[value=1]()
603%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
604%none : NoneType = prim::Constant()
605%res : Tensor = aten::clone(%y0, %none)
606return (%res)
607)IR";608graph = getGraphFromIR(embedding_bag_max_last_offset_ir);609RemoveUnnecessaryOutputs(graph);610torch::jit::testing::FileCheck()611.check("static_runtime::embedding_bag")612->run(*graph);613
614const std::string embedding_bag_normal_ir = R"IR(615graph(%weight, %indices, %offsets):
616%scale_grad_by_freq : bool = prim::Constant[value=0]()
617%mode : int = prim::Constant[value=0]()
618%sparse : bool = prim::Constant[value=0]()
619%per_sample_weights : NoneType = prim::Constant()
620%include_last_offset : bool = prim::Constant[value=0]()
621%y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
622%none : NoneType = prim::Constant()
623%res0 : Tensor = aten::clone(%y0, %none)
624%res1 : Tensor = aten::clone(%y1, %none)
625%res2 : Tensor = aten::clone(%y2, %none)
626%res3 : Tensor = aten::clone(%y3, %none)
627return (%res0, %res1, %res2, %res3)
628)IR";629graph = getGraphFromIR(embedding_bag_normal_ir);630RemoveUnnecessaryOutputs(graph);631torch::jit::testing::FileCheck()632.check_not("static_runtime::embedding_bag")633->run(*graph);634
635at::Tensor weight = torch::randn({3, 11}, at::ScalarType::Float);636at::Tensor input = torch::tensor({0, 1, 0, 2});637at::Tensor offset = torch::tensor({0, 2, 4});638std::vector<IValue> args{weight, input, offset};639testStaticRuntime(embedding_bag_default_ir, args);640testStaticRuntime(embedding_bag_mean_ir, args);641testStaticRuntime(embedding_bag_max_last_offset_ir, args);642
643at::Tensor weight2 = torch::randn({10, 11}, at::ScalarType::Float);644at::Tensor input2 = torch::tensor({0, 1, 0, 2, 1});645at::Tensor offset2 = torch::tensor({0, 1, 2, 3, 4, 5});646std::vector<IValue> args2{weight2, input2, offset2};647testStaticRuntime(embedding_bag_default_ir, args, args2);648testStaticRuntime(embedding_bag_mean_ir, args, args2);649testStaticRuntime(embedding_bag_max_last_offset_ir, args, args2);650}
651
652TEST(StaticRuntime, EmbeddingBagWithMixedInt32Int64Input) {653const std::string embedding_bag_default = R"JIT(654def forward(self, a: Tensor, b: Tensor, c: Tensor):
655x, y, z, _ = torch.embedding_bag(a, b, c)
656return (x.clone(), y.clone(), z.clone(), _.clone())
657)JIT";658auto weight = torch::randn({3, 11}, at::ScalarType::Float);659auto input = torch::tensor({0, 1, 0, 2}, at::ScalarType::Long);660auto offset = torch::tensor({0, 2, 4}, at::ScalarType::Int);661std::vector<IValue> args{weight, input, offset};662testStaticRuntime(embedding_bag_default, args);663}
664
665TEST(StaticRuntime, LayerNorm) {666const std::string layer_norm_with_weights = R"JIT(667def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
668return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
669)JIT";670
671const std::string layer_norm_without_weights = R"JIT(672def forward(self, input: Tensor, normalized_shape: List[int]):
673return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone()
674)JIT";675
676const std::string layer_norm_with_noncontiguous_input = R"JIT(677def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
678input = torch.transpose(input, 1, 2)
679return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
680)JIT";681
682const auto a = torch::rand({1, 2, 2, 2});683const auto b = torch::rand({3, 2, 2, 2});684for (int normalized_size : {2, 3}) {685std::vector<int64_t> normalized_shape(normalized_size, 2);686const auto weight = torch::rand(normalized_shape);687const auto bias = torch::rand(normalized_shape);688
689std::vector<IValue> args{a, normalized_shape, weight, bias};690std::vector<IValue> args1{b, normalized_shape, weight, bias};691testStaticRuntime(layer_norm_with_weights, args);692testStaticRuntime(layer_norm_with_weights, args, args1);693testStaticRuntime(layer_norm_with_noncontiguous_input, args);694
695args = {a, normalized_shape};696testStaticRuntime(layer_norm_without_weights, args);697testStaticRuntime(layer_norm_without_weights, args, {b, normalized_shape});698}699}
700
701TEST(StaticRuntime, Bmm) {702const auto bmm_script = R"JIT(703def forward(self, inp: Tensor, mat2: Tensor):
704return torch.bmm(inp, mat2).clone()
705)JIT";706
707auto a = at::randn({10, 4, 5});708auto b = at::randn({10, 5, 6});709
710auto c = at::randn({12, 5, 6});711auto d = at::randn({12, 6, 7});712
713std::vector<IValue> args{a, b};714std::vector<IValue> args1{c, d};715testStaticRuntime(bmm_script, args);716testStaticRuntime(bmm_script, args1);717testStaticRuntime(bmm_script, args, args1);718}
719
720TEST(StaticRuntime, Addmm) {721const auto addmm_script = R"JIT(722def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float):
723return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone()
724)JIT";725auto inp1 = at::randn({5});726auto mat1 = at::randn({3, 4});727auto mat2 = at::randn({4, 5});728
729auto inp2 = at::randn({3, 7});730auto mat3 = at::randn({3, 6});731auto mat4 = at::randn({6, 7});732
733std::vector<IValue> args{inp1, mat1, mat2, 1.0, 2.0};734std::vector<IValue> args1{inp2, mat3, mat4, 2.0, 1.0};735testStaticRuntime(addmm_script, args);736testStaticRuntime(addmm_script, args1);737testStaticRuntime(addmm_script, args, args1);738}
739
740TEST(StaticRuntime, Abs) {741const auto abs_script = R"JIT(742def forward(self, a):
743return a.abs().clone()
744)JIT";745auto a = at::randn({2, 3});746auto b = at::randn({4, 2, 3});747std::vector<IValue> args{a};748std::vector<IValue> args2{b};749testStaticRuntime(abs_script, args);750testStaticRuntime(abs_script, args, args2);751}
752
753TEST(StaticRuntime, Binary) {754const auto add_script = R"JIT(755def forward(self, a, b):
756c = a + b
757return (c.clone())
758)JIT";759
760const auto add_script_ints = R"JIT(761def forward(self, a: int, b: int):
762c = a + b
763d = c + 1
764return d
765)JIT";766
767const auto add_list_script = R"JIT(768def forward(self, a: List[int], b: List[int]):
769c = a + b
770return c[::]
771)JIT";772
773const auto list_construct_script = R"JIT(774def forward(self, a, b):
775return [a, b]
776)JIT";777
778const auto list_construct_script_2 = R"JIT(779def forward(self, a, b):
780c = a + a
781return [c, c]
782)JIT";783
784const auto list_construct_script_3 = R"JIT(785def forward(self, a, b):
786c = a + a
787return [c, c.flatten()]
788)JIT";789
790const auto list_unpack_script = R"JIT(791def forward(self, a, b):
792c = [a, b]
793x, y = c
794z = x + y
795return z.clone()
796)JIT";797
798const auto list_unpack_script_2 = R"JIT(799def forward(self, a, b):
800c = [a, b]
801x, y = c
802z = (x, y)
803return z
804)JIT";805
806const auto tuple_construct_script = R"JIT(807def forward(self, a, b):
808return (a, b)
809)JIT";810
811const auto tuple_construct_script_2 = R"JIT(812def forward(self, a, b):
813return (a.flatten(), b)
814)JIT";815
816auto a = at::randn({2, 3});817auto b = at::ones({2, 3});818
819auto c = at::randn({4, 2, 3});820auto d = at::ones({4, 2, 3});821
822std::vector<IValue> args{a, b};823
824testStaticRuntime(add_script, args);825testStaticRuntime(add_script_ints, {1, 2});826testStaticRuntime(add_script, args, {c, d});827testStaticRuntime(list_construct_script, args);828testStaticRuntime(list_construct_script_2, args);829testStaticRuntime(list_construct_script_3, args);830testStaticRuntime(list_unpack_script, args);831testStaticRuntime(list_unpack_script_2, args);832testStaticRuntime(tuple_construct_script, args);833testStaticRuntime(tuple_construct_script_2, args);834
835std::vector<IValue> list_args{836c10::List<int64_t>{1, 2, 3}, c10::List<int64_t>{4, 5, 6}};837testStaticRuntime(add_list_script, list_args);838}
839
840TEST(StaticRuntime, MatMul) {841const auto aten_matmul = R"JIT(842def forward(self, a: Tensor, b: Tensor):
843return torch.matmul(a, b).clone()
844)JIT";845
846// 1-D, 1-D847std::vector<IValue> args{at::randn({3}), at::randn({3})};848testStaticRuntime(aten_matmul, args);849// 2-D, 2-D850std::vector<IValue> args1 = {at::randn({3, 2}), at::randn({2, 3})};851testStaticRuntime(aten_matmul, args1);852// 1-D, 2-D853std::vector<IValue> args2 = {at::randn({3}), at::randn({3, 5})};854testStaticRuntime(aten_matmul, args2);855// 2-D, 1-D856std::vector<IValue> args3 = {at::randn({3, 5}), at::randn({5})};857testStaticRuntime(aten_matmul, args3);858// > 2-D , > 2-D859std::vector<IValue> args4 = {at::randn({3, 1, 4, 5}), at::randn({2, 5, 6})};860testStaticRuntime(aten_matmul, args4);861
862testStaticRuntime(aten_matmul, args3, args4);863}
864
865TEST(StaticRuntime, Sign) {866const auto sign_tensor = R"JIT(867def forward(self, input: Tensor):
868return torch.sign(input).clone()
869)JIT";870
871auto a = at::randn({2, 3});872auto b = at::randn({4, 3, 2});873
874std::vector<IValue> args{a};875testStaticRuntime(sign_tensor, args);876testStaticRuntime(sign_tensor, args, {b});877}
878
879TEST(StaticRuntime, Div) {880const auto div_tensor = R"JIT(881def forward(self, a: Tensor, b: Tensor):
882return torch.div(a, b).clone()
883)JIT";884
885const auto div_scalar = R"JIT(886def forward(self, a: Tensor, b: int):
887return torch.div(a, b).clone()
888)JIT";889
890const auto div_tensor_mode = R"JIT(891def forward(self, a: Tensor, b: Tensor, c: str):
892return torch.div(a, b, rounding_mode=c).clone()
893)JIT";894
895const auto div_scalar_mode = R"JIT(896def forward(self, a: Tensor, b: float, c: str):
897return torch.div(a, b, rounding_mode=c).clone()
898)JIT";899
900const auto div_strided = R"JIT(901def forward(self, a: Tensor, b: Tensor):
902a_strided = torch.transpose(a, 0, 1)
903b_strided = torch.transpose(b, 0, 1)
904return torch.div(a_strided, b_strided).clone()
905)JIT";906
907auto a = at::randn({2, 3});908auto b = at::randn({2, 3});909auto bs = at::randn({3, 2}).transpose(0, 1);910auto c = at::randn({4, 3, 2});911auto d = at::randn({4, 3, 2});912auto ds = at::randn({3, 4, 2}).transpose(0, 1);913
914std::vector<IValue> args0{a, b};915testStaticRuntime(div_tensor, args0);916testStaticRuntime(div_tensor, args0, {c, d});917
918testStaticRuntime(div_strided, args0);919testStaticRuntime(div_strided, args0, {c, d});920
921testStaticRuntime(div_tensor, {a, bs});922testStaticRuntime(div_tensor, {a, bs}, {c, ds});923
924std::vector<IValue> args1{a, 3};925testStaticRuntime(div_scalar, args1);926testStaticRuntime(div_scalar, args1, {c, 4});927
928std::vector<IValue> args2{a, b, "floor"};929testStaticRuntime(div_tensor_mode, args2);930testStaticRuntime(div_tensor_mode, args2, {c, d, "floor"});931
932std::vector<IValue> args3{a, 2.3, "trunc"};933testStaticRuntime(div_scalar_mode, args3);934testStaticRuntime(div_scalar_mode, args3, {c, 1.5, "trunc"});935}
936
937TEST(StaticRuntime, Mul) {938const auto mul_tensor = R"JIT(939def forward(self, a: Tensor, b: Tensor):
940return torch.mul(a, b).clone()
941)JIT";942
943const auto mul_scalar = R"JIT(944def forward(self, a: Tensor, b: int):
945return torch.mul(a, b).clone()
946)JIT";947
948const auto mul_list = R"JIT(949def forward(self, a: List[int], n: int):
950b = a * n
951return b[::]
952)JIT";953
954auto a = at::randn({3, 3});955auto b = at::randn({3, 3});956auto c = at::randn({3, 3, 3});957auto d = at::randn({3, 3, 3});958
959std::vector<IValue> tensor_args1{a, b};960std::vector<IValue> tensor_args2{c, d};961
962testStaticRuntime(mul_tensor, tensor_args1);963testStaticRuntime(mul_tensor, tensor_args1, tensor_args2);964
965std::vector<IValue> scalar_args1{a, 42};966std::vector<IValue> scalar_args2{c, 42};967
968testStaticRuntime(mul_scalar, scalar_args1);969testStaticRuntime(mul_scalar, scalar_args1, scalar_args2);970
971std::vector<IValue> list_args{c10::List<int64_t>{1, 2}, 3};972testStaticRuntime(mul_list, list_args);973}
974
975TEST(StaticRuntime, Log) {976const auto log_tensor = R"JIT(977def forward(self, inp: Tensor):
978a = torch.log(inp).clone()
979return (a)
980)JIT";981
982// Ensure that the input values are valid.983auto a = at::abs(at::randn({2, 3}));984auto b = at::abs(at::randn({4, 3, 2}));985
986std::vector<IValue> args{a};987testStaticRuntime(log_tensor, args);988testStaticRuntime(log_tensor, args, {b});989}
990
991TEST(StaticRuntime, Sub) {992const auto sub_tensor = R"JIT(993def forward(self, a: Tensor, b: Tensor):
994return torch.sub(a, b).clone()
995)JIT";996
997const auto sub_scalar = R"JIT(998def forward(self, a: Tensor, b: int):
999return torch.sub(a, b).clone()
1000)JIT";1001
1002const auto sub_tensor_alpha = R"JIT(1003def forward(self, a: Tensor, b: Tensor, c: float):
1004return torch.sub(a, b, alpha=c).clone()
1005)JIT";1006
1007const auto sub_scalar_alpha = R"JIT(1008def forward(self, a: Tensor, b: float, c: int):
1009return torch.sub(a, b, alpha=c).clone()
1010)JIT";1011
1012const auto sub_two_scalars = R"JIT(1013def forward(self, a: int, b: int):
1014return (a - b - b)
1015)JIT";1016
1017auto a = at::randn({2, 3});1018auto b = at::randn({2, 3});1019auto c = at::randn({4, 3, 2});1020auto d = at::randn({4, 3, 2});1021
1022std::vector<IValue> args0{a, b};1023testStaticRuntime(sub_tensor, args0);1024testStaticRuntime(sub_tensor, args0, {c, d});1025
1026std::vector<IValue> args1{a, 3};1027testStaticRuntime(sub_scalar, args1);1028testStaticRuntime(sub_scalar, args1, {c, 4});1029
1030std::vector<IValue> args2{a, b, 2.3};1031testStaticRuntime(sub_tensor_alpha, args2);1032testStaticRuntime(sub_tensor_alpha, {c, d, 3.1});1033
1034std::vector<IValue> args3{a, 2.3, 4};1035testStaticRuntime(sub_scalar_alpha, args3);1036testStaticRuntime(sub_scalar_alpha, {c, 1.3, 2});1037
1038std::vector<IValue> args4{1, 2};1039testStaticRuntime(sub_two_scalars, args4);1040}
1041
TEST(StaticRuntime, NanToNum) {
  // Verifies aten::nan_to_num replaces NaN / +inf / -inf with the given
  // scalar substitutes. clone() puts the output under the memory planner.
  const auto nan_to_num_script = R"JIT(
    def forward(self, a: Tensor, nan: float, posinf: float, neginf: float):
        return torch.nan_to_num(a, nan, posinf, neginf).clone()
  )JIT";

  const auto inf = std::numeric_limits<double>::infinity();
  const auto nan = std::numeric_limits<double>::quiet_NaN();

  // Small input built directly with the special values.
  auto a = torch::tensor({{1.0, nan}, {-inf, inf}});
  // Larger input: poke NaN/±inf into a fresh randn tensor through its raw
  // float storage (positions are arbitrary but fixed).
  auto b = at::randn({3, 6});
  float* b_data = b.data_ptr<float>();
  b_data[0] = nan;
  b_data[4] = -inf;
  b_data[11] = inf;
  b_data[13] = nan;

  std::vector<IValue> args1{a, 1.0, 2.0, -2.0};
  std::vector<IValue> args2{b, 1.0, 2.0, -2.0};

  // use_allclose/use_equalnan so NaN outputs compare equal to NaN.
  testStaticRuntime(
      nan_to_num_script,
      args1,
      /*args2*/ {},
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
  // Second arg set has a bigger tensor, exercising the resize path.
  testStaticRuntime(
      nan_to_num_script,
      args1,
      args2,
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
}
1075
1076TEST(StaticRuntime, Stack) {1077const auto stack_dim = R"JIT(1078def forward(self, a: Tensor, b: Tensor, dim: int):
1079inputs = [a]
1080inputs.append(b) # mutation to avoid using VarStack
1081return torch.stack(inputs, dim = dim).clone()
1082)JIT";1083
1084const auto stack_three = R"JIT(1085def forward(self, a: Tensor, b: Tensor, c: Tensor):
1086inputs = [a, b]
1087inputs.append(c) # mutation to avoid using VarStack
1088return torch.stack(inputs).clone()
1089)JIT";1090
1091auto a = at::randn({2, 2});1092auto b = at::randn({2, 2});1093auto c = at::randn({2, 2});1094
1095auto d = at::randn({3, 3, 3});1096auto e = at::randn({3, 3, 3});1097auto f = at::randn({3, 3, 3});1098
1099std::vector<IValue> args1_dim{a, b, 0};1100std::vector<IValue> args2_dim{d, e, 1};1101std::vector<IValue> args_dim_negative{d, e, -1};1102
1103std::vector<IValue> args1_three_tensors{a, b, c};1104std::vector<IValue> args2_three_tensors{d, e, f};1105
1106testStaticRuntime(stack_dim, args1_dim);1107testStaticRuntime(stack_dim, args1_dim, args2_dim);1108
1109testStaticRuntime(stack_dim, args_dim_negative);1110
1111testStaticRuntime(stack_three, args1_three_tensors);1112testStaticRuntime(stack_three, args1_three_tensors, args2_three_tensors);1113}
1114
1115TEST(StaticRuntime, ReLU) {1116const auto relu_script = R"JIT(1117def forward(self, a: Tensor):
1118return torch.relu(a).clone()
1119)JIT";1120auto a = at::randint(-10, 10, {2, 4});1121auto b = at::randint(-10, 10, {3, 6});1122
1123std::vector<IValue> args1{a};1124std::vector<IValue> args2{b};1125
1126testStaticRuntime(relu_script, args1);1127testStaticRuntime(relu_script, args1, args2);1128}
1129
1130TEST(StaticRuntime, Tanh) {1131const auto tanh_script = R"JIT(1132def forward(self, a):
1133return torch.tanh(a).clone()
1134)JIT";1135auto a = at::randn({2, 2});1136auto b = at::randn({3, 3, 3});1137
1138std::vector<IValue> args1{a};1139std::vector<IValue> args2{b};1140
1141testStaticRuntime(tanh_script, args1, /*args2*/ {}, /*use_allclose*/ true);1142testStaticRuntime(tanh_script, args1, args2, /*use_allclose*/ true);1143}
1144
TEST(StaticRuntime, Norm) {
  // Covers four aten::norm overloads: (p), (p, dtype), (p, dim, keepdim),
  // and (p, dim, keepdim, dtype). Each clones so the norm output is
  // managed by the memory planner.
  const auto norm_2arg = R"JIT(
    def forward(self, a: Tensor, p: int):
        return torch.norm(a, p).clone()
  )JIT";

  const auto norm_3arg = R"JIT(
    def forward(self, a: Tensor, p: int, dtype: int):
        return torch.norm(a, p, dtype=dtype).clone()
  )JIT";

  const auto norm_4arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool):
        return torch.norm(a, p, dim, keepdim).clone()
  )JIT";

  const auto norm_5arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool, dtype: int):
        return torch.norm(a, p, dim, keepdim, dtype=dtype).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 5});
  auto dim = std::vector<int64_t>({1});
  auto dtype = at::ScalarType::Float;

  // Full reductions: the trailing (false, false, false) disables
  // use_allclose, use_equalnan and check_resize — presumably because the
  // fully-reduced output does not grow with the input (TODO confirm).
  std::vector<IValue> args2{a, 2};
  testStaticRuntime(norm_2arg, args2);
  testStaticRuntime(norm_2arg, args2, {b, 2}, false, false, false);

  std::vector<IValue> args3{a, 2, dtype};
  testStaticRuntime(norm_3arg, args3);
  testStaticRuntime(norm_3arg, args3, {b, 2, dtype}, false, false, false);

  // Dim-wise reductions keep a per-slice output, so the default resize
  // check stays on for the second (larger) arg set.
  std::vector<IValue> args4{a, 3, dim, false};
  testStaticRuntime(norm_4arg, args4);
  testStaticRuntime(norm_4arg, args4, {b, 3, dim, false});

  std::vector<IValue> args5{a, 4, dim, true, dtype};
  testStaticRuntime(norm_5arg, args5);
  testStaticRuntime(norm_5arg, args5, {b, 4, dim, true, dtype});
}
1187
TEST(StaticRuntime, Reshape) {
  // A battery of reshape graphs covering: plain reshape, reshape of a
  // non-contiguous (transposed) input, aliasing outputs, the
  // reshape_copy/flatten_copy replacement passes, in-place mutation of a
  // reshape alias, and relu on a reshaped non-contiguous tensor.
  const auto reshape_script_1 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.reshape(shape)
        return b + b
  )JIT";

  const auto reshape_script_2 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        return b.reshape(shape)
  )JIT";

  const auto reshape_script_3 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape)
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return b.reshape(shape), g
  )JIT";

  // exercise reshape_copy and flatten_copy
  const auto reshape_script_4 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        k = inp + inp
        a = k + k
        b = a.reshape(shape)
        c = a.flatten().reshape(shape)
        return b + c
  )JIT";

  // exercise reshape_copy
  const auto reshape_script_5 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape).relu()
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return g
  )JIT";

  // sigmoid_() mutates `b`, which aliases `a`'s reshape — exercises the
  // runtime's handling of in-place ops on view outputs.
  const auto reshape_inplace_script = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = b.sigmoid_()
        d = c + c
        e = a + a
        f = b + b
        return (d, e, f)
  )JIT";

  // b is in_contiguous
  const auto reshape_incontiguous_script = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        c = b.reshape(shape)
        c = c.relu()
        return (c)
  )JIT";

  // 2x3 input reshaped to 3x2; both shapes hold 6 elements.
  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  // 4x5 input reshaped to 5x1x2x2 (20 elements) for the dynamic runs.
  auto c = at::randn({4, 5});
  auto d = std::vector<int64_t>({5, 1, 2, 2});
  std::vector<IValue> args1{c, d};

  // Static-shape runs.
  testStaticRuntime(reshape_script_1, args);
  testStaticRuntime(reshape_script_2, args);
  testStaticRuntime(reshape_script_3, args);
  testStaticRuntime(reshape_script_4, args);
  testStaticRuntime(reshape_script_5, args);
  testStaticRuntime(reshape_inplace_script, args);
  testStaticRuntime(reshape_incontiguous_script, args);

  // Dynamic-shape runs: rerun each graph with the larger input/shape.
  testStaticRuntime(reshape_script_1, args, args1);
  testStaticRuntime(reshape_script_2, args, args1);
  testStaticRuntime(reshape_script_3, args, args1);
  testStaticRuntime(reshape_script_4, args, args1);
  testStaticRuntime(reshape_script_5, args, args1);
  testStaticRuntime(reshape_inplace_script, args, args1);
  testStaticRuntime(reshape_incontiguous_script, args, args1);
}
1280
1281TEST(StaticRuntime, Repeat) {1282const std::string repeat = R"JIT(1283def forward(self, a: Tensor, repeats: List[int]):
1284return torch.repeat(a, repeats).clone()
1285)JIT";1286
1287auto a = at::randn({2, 3});1288auto b = at::randn({4, 3});1289auto c = std::vector<int64_t>({1, 2});1290auto d = std::vector<int64_t>({2, 3});1291std::vector<IValue> args1{a, c};1292std::vector<IValue> args2{b, d};1293
1294testStaticRuntime(repeat, args1);1295testStaticRuntime(repeat, args2);1296testStaticRuntime(repeat, args1, args2);1297}
1298
TEST(StaticRuntime, Flatten) {
  // exercise flatten_copy
  const auto flatten_script_1 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a * a
        c = torch.flatten(b, start_dim, end_dim)
        d = torch.relu(c)
        return d
  )JIT";

  // Flatten of a transposed (non-contiguous) input.
  const auto flatten_script_2 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a.transpose(0, 1)
        return torch.flatten(b, start_dim, end_dim).clone()
  )JIT";

  // Runs both scripts on `shape`, then reruns with the first dimension
  // scaled by 6 to force the memory planner to resize its buffers.
  auto test_flatten =
      [&](std::vector<int64_t> shape, int64_t start_dim, int64_t end_dim) {
        std::vector<int64_t> shape1(shape);
        if (shape1.size() > 0) {
          shape1[0] *= 6;
        }
        auto a = at::randn(shape);
        auto b = at::randn(shape1);
        std::vector<IValue> args{a, start_dim, end_dim};
        // A 0-dim input has no dimension to grow, so skip the resize check.
        bool check_resize = shape1.size() > 0;
        testStaticRuntime(flatten_script_1, args);
        testStaticRuntime(
            flatten_script_1,
            args,
            {b, start_dim, end_dim},
            false, /* use_allclose */
            false, /* use_equalnan */
            check_resize);
        // transpose(0, 1) needs at least two leading dims to matter; only
        // run script 2 for rank > 2 inputs.
        if (shape.size() > 2) {
          testStaticRuntime(flatten_script_2, args);
          testStaticRuntime(flatten_script_2, args, {b, start_dim, end_dim});
        }
      };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  // Zero-sized dimensions must be handled.
  test_flatten({0, 1, 3, 0}, 1, 2);
  // start_dim == end_dim is a no-op flatten.
  test_flatten({2, 3}, 1, 1);
  // Scalar (0-dim) input.
  test_flatten({}, 0, 0);
}
1345
1346TEST(StaticRuntime, pow) {1347const auto pow_script_ten_sca = R"JIT(1348def forward(self, input : Tensor, exponent : int):
1349return torch.pow(input, exponent).clone()
1350)JIT";1351
1352const auto pow_script_ten_ten = R"JIT(1353def forward(self, input : Tensor, exponent : Tensor):
1354return torch.pow(input, exponent).clone()
1355)JIT";1356
1357const auto pow_script_sca_ten = R"JIT(1358def forward(self, input : int, exponent : Tensor):
1359return torch.pow(input, exponent).clone()
1360)JIT";1361
1362auto a = at::randn({2, 3});1363auto b = at::randn({2, 3});1364auto c = at::randn({4, 3, 2});1365auto d = at::randn({4, 3, 2});1366
1367std::vector<IValue> args0{a, 4};1368testStaticRuntime(pow_script_ten_sca, args0);1369testStaticRuntime(pow_script_ten_sca, args0, {c, 4});1370
1371std::vector<IValue> args1{at::abs(a), b};1372testStaticRuntime(pow_script_ten_ten, args1);1373testStaticRuntime(pow_script_ten_ten, args1, {at::abs(c), d});1374
1375std::vector<IValue> args2{5, b};1376testStaticRuntime(pow_script_sca_ten, args2);1377testStaticRuntime(pow_script_sca_ten, args2, {3, d});1378}
1379
TEST(StaticRuntime, to) {
  // Exercises the main aten::to overloads (dtype, other-tensor, prim
  // dtype with Optional), plus graphs where the `to` output aliases or
  // escapes, across all non_blocking/copy/memory-format combinations.
  const auto to_script_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  // `to` applied to a permuted (strided, non-contiguous) input.
  const auto to_script_dtype_strided = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        b = input.permute(0, 2, 3, 1)
        return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  const auto to_script_prim_dtype = R"JIT(
    def forward(self, input:Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy).clone()
  )JIT";

  const auto to_script_other = R"JIT(
    def forward(self, input:Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, other, non_blocking, copy, memory_format).clone()
  )JIT";

  // if input is float tensor, b could be alias of a
  const auto to_script_alias = R"JIT(
    def forward(self, input:Tensor):
        a = input + input
        b = a.float()
        c = b * b
        return (c)
  )JIT";

  // `d` feeds the returned `e`, so d's output cannot be memory-managed.
  const auto to_script_fails_managed_output_check = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return e
  )JIT";

  // Both the half intermediate and the float conversion escape via tuple.
  const auto to_script_select_tensor_output_into_tuple = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return (d, e)
  )JIT";

  const auto to_script_memory_planning_fail = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float().relu()
        return e
  )JIT";

  // b = target dtype, c = non_blocking, d = copy, e = memory format.
  auto test_to = [&](at::ScalarType b, bool c, bool d, c10::MemoryFormat e) {
    auto a = at::randn({4, 3, 1, 2});
    auto other = at::randn({4, 3, 1, 2}).to(b);
    auto a2 = at::randn({3, 2, 2, 4});
    auto a2_other = at::randn({3, 2, 2, 4}).to(b);

    std::vector<IValue> args0{a, b, c, d, e};
    std::vector<IValue> args1{a, b, c, d};
    std::vector<IValue> args2{a, other, c, d, e};
    std::vector<IValue> args2WithDifferentOtherType{
        a, at::randn({4, 3, 1, 2}, ScalarType::Double), c, d, e};
    // std::nullopt -> dtype=None for the Optional[int] overload.
    std::vector<IValue> args3{a, std::nullopt, c, d};

    // Second arg set changes only the dtype, checking re-dispatch when the
    // requested dtype differs between iterations (no resize expected).
    std::vector<IValue> args0WithInt{a, ScalarType::Int, c, d, e};
    testStaticRuntime(
        to_script_dtype,
        args0,
        args0WithInt,
        /* default for use_allclose */ false,
        /* default for use_equalnan */ false,
        /* check_resize */ false);
    testStaticRuntime(to_script_dtype_strided, args0);
    testStaticRuntime(to_script_prim_dtype, args1);
    if (!d) {
      testStaticRuntime(to_script_prim_dtype, args3);
    }
    // Second set of args tests case where the `other` tensor's dtype
    // changes between iterations.
    testStaticRuntime(
        to_script_other,
        args2,
        args2WithDifferentOtherType,
        /* default for use_allclose */ false,
        /* default for use_equalnan */ false,
        /* check_resize */ false);
    testStaticRuntime(to_script_alias, {a});

    testStaticRuntime(to_script_memory_planning_fail, {a, a});
    testStaticRuntime(to_script_fails_managed_output_check, {a, a});
    testStaticRuntime(to_script_select_tensor_output_into_tuple, {a, a});

    // dynamic shapes
    testStaticRuntime(to_script_dtype, args0, {a2, b, c, d, e});
    testStaticRuntime(to_script_dtype_strided, args0, {a2, b, c, d, e});
    testStaticRuntime(to_script_prim_dtype, args1, {a2, b, c, d});
    if (!d) {
      testStaticRuntime(to_script_prim_dtype, args3, {a2, std::nullopt, c, d});
    }
    testStaticRuntime(to_script_other, args2, {a2, a2_other, c, d, e});
    testStaticRuntime(to_script_alias, {a}, {a2});
  };
  // Sweep all flag combinations against representative dtype/layout pairs.
  for (const bool non_blocking : {false, true}) {
    for (const bool copy : {false, true}) {
      // float->float, NCHW->NHWC
      test_to(
          at::ScalarType::Float,
          non_blocking,
          copy,
          c10::MemoryFormat::ChannelsLast);
      // float->half
      test_to(
          at::ScalarType::Half,
          non_blocking,
          copy,
          c10::MemoryFormat::Preserve);
      // float->float
      test_to(
          at::ScalarType::Float,
          non_blocking,
          copy,
          c10::MemoryFormat::Contiguous);
      test_to(
          at::ScalarType::Bool,
          non_blocking,
          copy,
          c10::MemoryFormat::Contiguous);
      // TODO: check if fbgemm is enabled properly in this case
      // half->float, NCHW->NHWC
      test_to(
          at::ScalarType::Half,
          non_blocking,
          copy,
          c10::MemoryFormat::ChannelsLast);
    }
  }
}
1521
1522TEST(StaticRuntime, ExpandAs) {1523const auto expand_as_script = R"JIT(1524def forward(self, input: Tensor, other:Tensor):
1525a = input.expand_as(other)
1526return a.clone()
1527)JIT";1528
1529auto a = at::randn({3, 1});1530auto b = at::randn({3, 2});1531auto c = at::randn({4, 1});1532auto d = at::randn({4, 2});1533std::vector<IValue> args{a, b};1534std::vector<IValue> args2{c, d};1535testStaticRuntime(expand_as_script, args);1536testStaticRuntime(expand_as_script, args, args2);1537}
1538
TEST(StaticRuntime, Full) {
  // aten::full with the full factory signature; clone() keeps the result
  // under the memory planner.
  const auto full_script = R"JIT(
    def forward(self,
                size: List[int],
                fill_value: int,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.full(size,
                       fill_value,
                       dtype=dtype,
                       layout=layout,
                       device=device,
                       pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  c10::List<int64_t> size0{2, 5};
  std::vector<IValue> args{
      size0, 4, at::ScalarType::Int, at::kStrided, cpu, false};
  // Same size, different dtype — checks dtype changes between iterations.
  std::vector<IValue> args1{
      size0, 4, at::ScalarType::Float, at::kStrided, cpu, false};
  // Bigger size — exercises output resizing.
  c10::List<int64_t> size1{5, 6};
  std::vector<IValue> args2{
      size1, 5, at::ScalarType::Float, at::kStrided, cpu, false};
  testStaticRuntime(full_script, args);
  // check_resize=false: args1 only changes the dtype, not the size.
  testStaticRuntime(
      full_script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(full_script, args, args2);
}
1576
TEST(StaticRuntime, FullLike) {
  // aten::full_like with the full factory signature; output shape follows
  // the input tensor.
  const auto full_like_script = R"JIT(
    def forward(self,
                a: Tensor,
                fill_value: int,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool],
                memory_format: Optional[int]):
        b = torch.full_like(a,
                            fill_value,
                            dtype=dtype,
                            layout=layout,
                            device=device,
                            pin_memory=pin_memory,
                            memory_format=memory_format)
        return (b.clone())
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 4, 2});
  auto cpu = at::Device(DeviceType::CPU);
  std::vector<IValue> args{
      a,
      4,
      at::ScalarType::Int,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  // Same input tensor, different dtype.
  std::vector<IValue> args1{
      a,
      4,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  // Larger input tensor — exercises resizing.
  std::vector<IValue> args2{
      b,
      4,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  testStaticRuntime(full_like_script, args);
  // check_resize=false: args1 only changes the dtype, not the size.
  testStaticRuntime(
      full_like_script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(full_like_script, args, args2);
}
1634
1635TEST(StaticRuntime, Ones) {1636const auto script = R"JIT(1637def forward(self,
1638size: List[int],
1639dtype: Optional[int],
1640layout: Optional[int],
1641device: Optional[Device],
1642pin_memory: Optional[bool]):
1643a = torch.ones(size,
1644dtype=dtype,
1645layout=layout,
1646device=device,
1647pin_memory=pin_memory)
1648return (a.clone())
1649)JIT";1650
1651auto dtype = at::ScalarType::Int;1652auto cpu = at::Device(DeviceType::CPU);1653c10::List<int64_t> size0{2, 5};1654std::vector<IValue> args{size0, dtype, at::kStrided, cpu, false};1655c10::List<int64_t> size1{5, 6};1656std::vector<IValue> args2{size1, dtype, at::kStrided, cpu, false};1657testStaticRuntime(script, args);1658testStaticRuntime(script, args, args2);1659}
1660
TEST(StaticRuntime, OnesLike) {
  // aten::ones_like with the full factory signature; output shape follows
  // the input tensor.
  const auto script = R"JIT(
    def forward(self,
                input: Tensor,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool],
                memory_format: Optional[int]):
        a = torch.ones_like(input,
                            dtype=dtype,
                            layout=layout,
                            device=device,
                            pin_memory=pin_memory,
                            memory_format=memory_format)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  auto input0 = at::randn({2, 5});
  std::vector<IValue> args{
      input0,
      at::ScalarType::Int,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  // Same input, different dtype.
  std::vector<IValue> args1{
      input0,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  // Larger input — exercises resizing.
  auto input1 = at::randn({5, 6});
  std::vector<IValue> args2{
      input1,
      at::ScalarType::Float,
      at::kStrided,
      cpu,
      false,
      c10::MemoryFormat::Contiguous};
  testStaticRuntime(script, args);
  // check_resize=false: args1 only changes the dtype, not the size.
  testStaticRuntime(
      script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(script, args, args2);
}
1713
TEST(StaticRuntime, Zeros) {
  // aten::zeros with the full factory signature; clone() keeps the result
  // under the memory planner.
  const auto script = R"JIT(
    def forward(self,
                size: List[int],
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.zeros(size,
                        dtype=dtype,
                        layout=layout,
                        device=device,
                        pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  c10::List<int64_t> size0{2, 5};
  std::vector<IValue> args{
      size0, at::ScalarType::Int, at::kStrided, cpu, false};
  // Same size, different dtype.
  std::vector<IValue> args1{
      size0, at::ScalarType::Float, at::kStrided, cpu, false};
  // Larger size — exercises resizing.
  c10::List<int64_t> size1{5, 6};
  std::vector<IValue> args2{
      size1, at::ScalarType::Float, at::kStrided, cpu, false};
  testStaticRuntime(script, args);
  // check_resize=false: args1 only changes the dtype, not the size.
  testStaticRuntime(
      script,
      args,
      args1,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(script, args, args2);
}
1749
1750TEST(StaticRuntime, Linear) {1751const auto linear_script = R"JIT(1752def forward(self, inp: Tensor, weights: Tensor, bias: Optional[Tensor]) -> Tensor:
1753return torch.linear(inp, weights, bias).clone()
1754)JIT";1755
1756auto input = at::randn({1, 2});1757auto weights = at::randn({1, 2});1758auto bias = at::randn({1, 1});1759
1760std::vector<IValue> args{input, weights, bias};1761std::vector<IValue> args_no_bias{input, weights, std::nullopt};1762
1763auto input2 = at::randn({6, 3});1764auto weights2 = at::randn({6, 3});1765auto bias2 = at::randn({6, 6});1766
1767std::vector<IValue> args2{input2, weights2, bias2};1768std::vector<IValue> args2_no_bias{input2, weights2, std::nullopt};1769
1770testStaticRuntime(linear_script, args);1771testStaticRuntime(linear_script, args_no_bias);1772
1773testStaticRuntime(linear_script, args, args2);1774testStaticRuntime(linear_script, args, args2_no_bias);1775}
1776
1777TEST(StaticRuntime, VarCat) {1778const auto var_cat_script = R"JIT(1779def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
1780return torch.cat([inp1, inp2], dim).clone()
1781)JIT";1782
1783// 2D tensors - cat dim = 01784std::vector<IValue> args1 = {at::randn({4, 6}), at::randn({5, 6}), 0};1785testStaticRuntime(var_cat_script, args1);1786
1787// 3D tensors - cat dim = 11788std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 8, 6}), 1};1789testStaticRuntime(var_cat_script, args2);1790
1791// 3D tensors - cat dim = 21792std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};1793testStaticRuntime(var_cat_script, args3);1794
1795// Negative dim1796std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 7}), -1};1797testStaticRuntime(var_cat_script, args4);1798
1799testStaticRuntime(var_cat_script, args1, args2);1800}
1801
TEST(StaticRuntime, LeakyReLU) {
  // Compares the JIT interpreter against StaticModule on the same
  // scripted LeakyReLU model (negative slope baked in as a constant).
  torch::jit::Module mod = getLeakyReLUConstScriptModel();
  auto inputs = torch::randn({2, 2});

  // run jit graph executor
  std::vector<at::IValue> input_ivalues({inputs});
  at::Tensor output_1 = mod.forward(input_ivalues).toTensor();

  // run static runtime
  std::vector<c10::IValue> input_tensors({inputs});
  torch::jit::StaticModule smod(mod);
  at::Tensor output_2 = smod(input_tensors, {}).toTensor();
  // Ensure the memory planner released everything it managed.
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(output_1, output_2, 1e-6));
}
1817
1818static ProcessedNodeInputs createProcessedNodeInputs(1819c10::ArrayRef<uint16_t> inputs) {1820ProcessedNodeInputs result(inputs.size());1821for (const auto idx : c10::irange(inputs.size())) {1822result[idx] = inputs[idx];1823}1824return result;1825}
1826
1827static void checkProcessedNodeInputs(1828const ProcessedNodeInputs& io,1829c10::ArrayRef<uint16_t> inputs) {1830ASSERT_EQ(inputs.size(), io.size());1831for (const auto idx : c10::irange(inputs.size())) {1832EXPECT_EQ(inputs[idx], io[idx]);1833}1834}
1835
1836static void testProcessedNodeInputsRoundTrip(c10::ArrayRef<uint16_t> inputs) {1837auto io = createProcessedNodeInputs(inputs);1838checkProcessedNodeInputs(io, inputs);1839
1840ProcessedNodeInputs copied(io);1841checkProcessedNodeInputs(copied, inputs);1842ProcessedNodeInputs moved(std::move(io));1843checkProcessedNodeInputs(moved, inputs);1844}
1845
TEST(ProcessedNodeInputs, Basic) {
  // Sizes chosen to straddle the inline/outline storage boundary of
  // ProcessedNodeInputs (5 fits inline, 6+ spills to the heap).
  std::vector<std::vector<uint16_t>> testCases = {
      {}, // empty
      {0xABCD, 0x5a5a}, // inline
      {0x11, 0x22, 0x33, 0x44, 0x55}, // max inline size
      {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}, // minimum outline size
      std::vector<uint16_t>(100, 0x5a), // large outline size
  };

  // Cross-product of all size combinations: verifies copy-assignment and
  // move-assignment across every inline<->outline transition.
  for (const auto& values : testCases) {
    testProcessedNodeInputsRoundTrip(values);
    for (const auto& values2 : testCases) {
      auto from = createProcessedNodeInputs(values);
      auto to = createProcessedNodeInputs(values2);

      // Copy-assign first; `from` must remain intact for the move below.
      to = from;
      checkProcessedNodeInputs(to, values);

      auto toMoveInto = createProcessedNodeInputs(values2);
      toMoveInto = std::move(from);
      checkProcessedNodeInputs(toMoveInto, values);
    }
  }
}
1870
1871TEST(StaticRuntime, isinstance) {1872const auto isinstance_int_script = R"JIT(1873def forward(self, a: Any):
1874return isinstance(a, int)
1875)JIT";1876
1877const auto isinstance_tensor_script = R"JIT(1878def forward(self, a: Any):
1879return isinstance(a, torch.Tensor)
1880)JIT";1881
1882const auto isinstance_many_types_script = R"JIT(1883def forward(self, a: Any):
1884return isinstance(a, (bool, int))
1885)JIT";1886
1887auto a = at::randn({2, 2});1888auto b = at::randn({2, 2, 2});1889
1890std::vector<at::IValue> args{a};1891std::vector<at::IValue> args2{b};1892
1893testStaticRuntime(isinstance_int_script, args);1894testStaticRuntime(isinstance_int_script, args, args2);1895
1896testStaticRuntime(isinstance_tensor_script, args);1897testStaticRuntime(isinstance_tensor_script, args, args2);1898
1899testStaticRuntime(isinstance_many_types_script, args);1900testStaticRuntime(isinstance_many_types_script, args, args2);1901}
1902
1903TEST(StaticRuntime, TypeCheck) {1904const auto typecheck_ir = R"IR(1905graph(%a.1 : Tensor,
1906%b.1 : Tensor):
1907%t0 : Float(2, 2, strides=[2, 1], device=cpu), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
1908return (%t0, %t1, %type_matched)
1909)IR";1910
1911auto a = at::zeros({2, 2}, at::kFloat);1912a.to(at::kCPU);1913auto b = at::ones({3, 3}, at::kFloat);1914auto c = at::ones({2, 2, 2}, at::kFloat);1915
1916std::vector<IValue> args_correct = {a, b};1917std::vector<IValue> args_incorrect = {a, c};1918
1919testStaticRuntime(typecheck_ir, args_correct);1920testStaticRuntime(typecheck_ir, args_correct, args_incorrect);1921}
1922
TEST(StaticRuntime, Index) {
  // aten::index via Python-style advanced indexing: boolean masks,
  // integer tensors, None insertion, and multi-tensor indexing.
  const auto index_without_none_script = R"JIT(
    def forward(self, a: Tensor, idx: Tensor):
        return a[idx].clone()
  )JIT";

  // Index with boolean mask
  auto a = at::arange(4, at::kFloat).view({2, 2});
  auto idx_a = torch::tensor({{0, 1}, {0, 0}}, at::kBool);
  std::vector<IValue> args_a{a, idx_a};

  // Index with tensor
  auto b = at::arange(27, at::kFloat).view({3, 3, 3});
  auto idx_b = torch::tensor({0, 1, 2}, at::kLong);
  std::vector<IValue> args_b{b, idx_b};

  testStaticRuntime(index_without_none_script, args_a);
  testStaticRuntime(index_without_none_script, args_a, args_b);

  const auto index_with_none_script = R"JIT(
    def forward(self, a: Tensor, idx: Tensor, none: Optional[Tensor]):
        return a[idx, none].clone()
  )JIT";

  // Index with None
  // When indexing with none, the shape of `f` becomes [2, 1, 2],
  // so the mask must be reshaped appropriately.
  auto f = at::arange(4, at::kFloat).view({2, 1, 2});
  auto idx_f_reshape = torch::tensor({{{0, 1}}, {{0, 0}}}, at::kBool);
  std::vector<IValue> args_f_with_none{f, idx_f_reshape};
  // Default-constructed IValue is None, filling the Optional[Tensor] slot.
  args_f_with_none.emplace_back();

  testStaticRuntime(index_with_none_script, args_f_with_none);
  testStaticRuntime(
      index_with_none_script,
      args_f_with_none,
      {IValue(b), IValue(idx_b), IValue()});

  const auto index_with_two_tensors_script = R"JIT(
    def forward(self, a: Tensor, idx_a: Tensor, idx_b: Tensor):
        return a[idx_a, idx_b].clone()
  )JIT";

  // Index with multiple tensors
  const auto& c = a; // 2x2 tensor
  auto idx_c1 = torch::tensor({0, 0}, at::kLong);
  auto idx_c2 = torch::tensor({0}, at::kLong);
  std::vector<IValue> args_c{c, idx_c1, idx_c2};

  const auto& d = b; // 3x3x3 tensor
  auto idx_d1 = torch::tensor({{0, 0, 2}, {0, 1, 1}}, at::kLong);
  auto idx_d2 = torch::tensor({{1, 1, 0}, {1, 0, 2}}, at::kLong);
  std::vector<IValue> args_d{d, idx_d1, idx_d2};

  testStaticRuntime(index_with_two_tensors_script, args_c, args_d);
}
1979
1980TEST(StaticRuntime, IndexSelect) {1981const std::string script = R"IR(1982graph(%self: Tensor, %dim: int, %index: Tensor):
1983%bias: None = prim::Constant()
1984%ret = aten::index_select(%self, %dim, %index)
1985%cloned = aten::clone(%ret, %bias)
1986return (%cloned)
1987)IR";1988
1989auto self0 = at::rand({6});1990auto dim0 = 0;1991auto index0 = at::randint(0, 5, {6}, torch::kInt32);1992std::vector<IValue> args{self0, dim0, index0};1993testStaticRuntime(script, args);1994
1995auto self1 = at::rand({128});1996auto dim1 = 0;1997auto index1 = at::randint(0, 127, {127}, torch::kInt32);1998std::vector<IValue> args2{self1, dim1, index1};1999testStaticRuntime(script, args, args2);2000}
2001
2002TEST(StaticRuntime, ClampMin) {2003const auto clamp_min_int_script = R"JIT(2004def forward(self, a: Tensor, b: int):
2005return torch.clamp_min(a, b).clone()
2006)JIT";2007
2008const auto clamp_min_float_script = R"JIT(2009def forward(self, a: Tensor, b: float):
2010return torch.clamp_min(a, b).clone()
2011)JIT";2012
2013auto a = at::randn({2, 2});2014auto b = at::randn({3, 3, 3});2015int scalar_int = 1;2016float scalar_float = 3.14;2017
2018std::vector<IValue> args_a_int{a, scalar_int};2019std::vector<IValue> args_b_int{b, scalar_int};2020
2021testStaticRuntime(clamp_min_int_script, args_a_int);2022testStaticRuntime(clamp_min_int_script, args_a_int, args_b_int);2023
2024std::vector<IValue> args_a_float{a, scalar_float};2025std::vector<IValue> args_b_float{b, scalar_float};2026
2027testStaticRuntime(clamp_min_float_script, args_a_float);2028testStaticRuntime(clamp_min_float_script, args_a_float, args_b_float);2029}
2030
TEST(StaticRuntime, Argmin) {
  // aten::argmin in its three forms: full reduction, per-dim, and
  // per-dim with keepdim=True.
  const auto argmin_script = R"JIT(
    def forward(self, a: Tensor):
        return torch.argmin(a).clone()
  )JIT";

  const auto argmin_with_dim_script = R"JIT(
    def forward(self, a: Tensor, dim: int):
        return torch.argmin(a, dim).clone()
  )JIT";

  const auto argmin_with_keep_dim_script = R"JIT(
    def forward(self, a: Tensor, dim: int):
        return torch.argmin(a, dim, True).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({17, 2, 1});

  testStaticRuntime(argmin_script, {a});
  // check_resize=false — presumably because the full reduction always
  // yields a 0-dim index tensor, so there is no growth to verify
  // (TODO confirm).
  testStaticRuntime(
      argmin_script,
      {a},
      {b},
      /* use_allclose */ false,
      /* use_equalnan */ false,
      /* check_resize */ false);

  int dim_a = 0;
  int dim_b = 1;

  std::vector<IValue> args_a{a, dim_a};
  std::vector<IValue> args_b{b, dim_b};

  testStaticRuntime(argmin_with_dim_script, args_a);
  testStaticRuntime(argmin_with_dim_script, args_a, args_b);

  testStaticRuntime(argmin_with_keep_dim_script, args_a);
  testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b);
}
2071
2072TEST(StaticRuntime, Softmax) {2073const auto softmax_script = R"JIT(2074def forward(self, a: Tensor, dim: int):
2075return torch.softmax(a, dim).clone()
2076)JIT";2077
2078const auto softmax_script_with_dtype = R"JIT(2079def forward(self, a: Tensor, dim: int, dtype: int):
2080return torch.softmax(a, dim, dtype=dtype).clone()
2081)JIT";2082
2083auto a = at::randn({2, 3});2084auto b = at::randn({3, 3, 3});2085
2086testStaticRuntime(softmax_script, {a, 0});2087testStaticRuntime(softmax_script, {a, 1});2088
2089testStaticRuntime(softmax_script, {b, 0});2090testStaticRuntime(softmax_script, {b, 1});2091testStaticRuntime(softmax_script, {b, 2});2092
2093testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float});2094testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float});2095}
2096
2097TEST(StaticRuntime, GetItem_Dict) {2098const auto getitem_dict_tensor_script = R"JIT(2099def forward(self, key: Tensor):
2100d = {key: 1}
2101return d[key]
2102)JIT";2103
2104const auto getitem_dict_int_script = R"JIT(2105def forward(self, key: int):
2106d = {key: 1}
2107return d[key]
2108)JIT";2109
2110const auto getitem_dict_str_script = R"JIT(2111def forward(self, key: str):
2112d = {key: 1}
2113return d[key]
2114)JIT";2115
2116int int_key = 0;2117std::string str_key = "str";2118
2119// No need to test these multiple times, args are not tensors2120testStaticRuntime(getitem_dict_int_script, {int_key});2121testStaticRuntime(getitem_dict_str_script, {str_key});2122
2123auto a = torch::tensor({1});2124auto b = torch::tensor({1, 1});2125
2126testStaticRuntime(getitem_dict_tensor_script, {a});2127testStaticRuntime(getitem_dict_tensor_script, {a}, {b});2128}
2129
2130TEST(StaticRuntime, GetItem_List) {2131const auto getitem_list_int_script = R"JIT(2132def forward(self, idx: int):
2133lst = [1, 2, 3]
2134return lst[idx]
2135)JIT";2136
2137const auto getitem_list_tensor_script = R"JIT(2138def forward(self, tensor: Tensor, idx: int):
2139lst = [tensor, tensor]
2140return lst[idx]
2141)JIT";2142
2143testStaticRuntime(getitem_list_int_script, {1});2144testStaticRuntime(getitem_list_int_script, {-1});2145
2146auto a = torch::tensor({1});2147auto b = torch::tensor({1, 1});2148
2149testStaticRuntime(getitem_list_tensor_script, {a, 1});2150testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1});2151}
2152
2153TEST(StaticRuntime, Transpose) {2154const auto transpose_script = R"JIT(2155def forward(self, a: Tensor, dim1: int, dim2: int):
2156return torch.transpose(a, dim1, dim2).clone()
2157)JIT";2158
2159auto a = at::randn({2, 2});2160int dim1_a = 0;2161int dim2_a = 1;2162std::vector<IValue> args_a{a, dim1_a, dim2_a};2163
2164auto b = at::randn({3, 3, 3});2165int dim1_b = 0;2166int dim2_b = 2;2167std::vector<IValue> args_b{b, dim1_b, dim2_b};2168
2169testStaticRuntime(transpose_script, args_a);2170testStaticRuntime(transpose_script, args_a, args_b);2171}
2172
2173TEST(StaticRuntime, Permute) {2174auto permute_script = R"JIT(2175def forward(self, a: Tensor, dims: List[int]):
2176return torch.permute(a, dims).clone()
2177)JIT";2178
2179auto a = at::randn({2, 2});2180c10::List<int64_t> dims_a{1, 0};2181std::vector<IValue> args_a{a, dims_a};2182
2183auto b = at::randn({3, 3, 3});2184c10::List<int64_t> dims_b{0, 2, 1};2185std::vector<IValue> args_b{b, dims_b};2186
2187testStaticRuntime(permute_script, args_a);2188testStaticRuntime(permute_script, args_a, args_b);2189
2190permute_script = R"JIT(2191def forward(self, a: Tensor, dims: List[int], shape: List[int]):
2192return torch.permute(a, dims).reshape(shape).clone()
2193)JIT";2194
2195a = at::randn({8, 16, 4});2196dims_a = {0, 2, 1};2197dims_b = {-1, 16};2198testStaticRuntime(permute_script, {a, dims_a, dims_b});2199}
2200
2201TEST(StaticRuntime, Slice) {2202const auto slice_script = R"JIT(2203def forward(self, a: Tensor, dim: int, start: int, end: int, step: int):
2204return a.slice(dim, start, end, step).clone()
2205)JIT";2206
2207auto a = at::randn({2, 2});2208int dim_a = 1;2209int start_a = 0;2210int end_a = 1;2211int step_a = 1;2212std::vector<IValue> args_a{a, dim_a, start_a, end_a, step_a};2213
2214auto b = at::randn({3, 3, 3});2215int dim_b = 2;2216int start_b = 0;2217int end_b = 1;2218int step_b = 2;2219std::vector<IValue> args_b{b, dim_b, start_b, end_b, step_b};2220
2221testStaticRuntime(slice_script, args_a);2222testStaticRuntime(slice_script, args_a, args_b);2223
2224const auto slice_script2 = R"JIT(2225def forward(self, a: Tensor, dim: int, step: int):
2226return a.slice(dim, None, None, step).clone()
2227)JIT";2228std::vector<IValue> args_c{b, dim_b, step_b};2229testStaticRuntime(slice_script2, args_c);2230}
2231
2232TEST(StaticRuntime, Narrow) {2233const auto narrow_with_int_script = R"JIT(2234def forward(self, a: Tensor, dim: int, start: int, length: int):
2235return a.narrow(dim, start, length).clone()
2236)JIT";2237
2238auto a = at::randn({5, 5});2239int dim_a = 0;2240int start_a_int = 3;2241int len_a = 2;2242std::vector<IValue> args_a{a, dim_a, start_a_int, len_a};2243
2244auto b = at::randn({5, 5, 5});2245int dim_b = 1;2246int start_b_int = 2;2247int len_b = 3;2248std::vector<IValue> args_b{b, dim_b, start_b_int, len_b};2249
2250testStaticRuntime(narrow_with_int_script, args_a);2251testStaticRuntime(narrow_with_int_script, args_a, args_b);2252}
2253
2254TEST(StaticRuntime, TupleUnpack) {2255const auto two_tuple_unpack_script = R"JIT(2256def forward(self, tup: Tuple[Tensor, Tensor]):
2257a, b = tup
2258return (a, b)
2259)JIT";2260
2261const auto three_tuple_unpack_script = R"JIT(2262def forward(self, tup: Tuple[Tensor, Tensor, Tensor]):
2263a, b, c = tup
2264return (a, b, c)
2265)JIT";2266
2267auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})});2268auto two_tup_large =2269c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})});2270
2271auto three_tup = c10::ivalue::Tuple::create(2272{at::randn({1}), at::randn({1}), at::randn({1})});2273auto three_tup_large = c10::ivalue::Tuple::create(2274{at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})});2275
2276testStaticRuntime(two_tuple_unpack_script, {two_tup});2277testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large});2278
2279testStaticRuntime(three_tuple_unpack_script, {three_tup});2280testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large});2281}
2282
2283TEST(StaticRuntime, Append) {2284const auto append_int_script = R"JIT(2285def forward(self, a: int):
2286lst = [1, 2, 3]
2287lst.append(a)
2288return lst
2289)JIT";2290
2291const auto append_tensor_script = R"JIT(2292def forward(self, a: Tensor):
2293lst = []
2294lst.append(a)
2295return lst
2296)JIT";2297
2298std::vector<IValue> args_int{1};2299
2300testStaticRuntime(append_int_script, args_int);2301
2302std::vector<IValue> args_tensor{at::randn({1})};2303std::vector<IValue> args_tensor_large{at::randn({2, 2})};2304
2305testStaticRuntime(append_tensor_script, args_tensor);2306testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large);2307}
2308
2309TEST(StaticRuntime, QuantizedLinear) {2310const std::string quantize_script = R"IR(2311graph(%input: Tensor, %weights: Tensor):
2312%scale: float = prim::Constant[value=1.]()
2313%zero_point: int = prim::Constant[value=1]()
2314%bias: None = prim::Constant()
2315%packed_params = quantized::linear_prepack(%weights, %bias)
2316%1254 = quantized::linear(%input, %packed_params, %scale, %zero_point)
2317%1249: Tensor = aten::dequantize(%1254)
2318return (%1249)
2319)IR";2320at::Tensor weight =2321at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);2322at::Tensor input =2323at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);2324
2325at::Tensor weight_2 =2326at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8);2327at::Tensor input_2 =2328at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8);2329
2330testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2});2331}
2332
2333TEST(StaticRuntime, QuantizedLinearDynamicFp16) {2334const std::string quantized_linear_dynamic_fp16_script = R"IR(2335graph(%input: Tensor, %weights: Tensor):
2336%bias: None = prim::Constant()
2337%packed_params = quantized::linear_prepack_fp16(%weights, %bias)
2338%output = quantized::linear_dynamic_fp16(%input, %packed_params)
2339%ret = aten::clone(%output, %bias)
2340return (%ret)
2341)IR";2342at::Tensor weight = torch::randn({3, 2}, torch::kFloat);2343at::Tensor input = torch::randn({3, 2}, torch::kFloat);2344
2345at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);2346at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);2347
2348testStaticRuntime(2349quantized_linear_dynamic_fp16_script,2350{input, weight},2351{input_2, weight_2});2352}
2353
2354TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) {2355const std::string quantized_linear_relu_dynamic_fp16_script = R"IR(2356graph(%input: Tensor, %weights: Tensor):
2357%bias: None = prim::Constant()
2358%packed_params = quantized::linear_prepack_fp16(%weights, %bias)
2359%output = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
2360%ret = aten::clone(%output, %bias)
2361return (%ret)
2362)IR";2363at::Tensor weight = torch::randn({3, 2}, torch::kFloat);2364at::Tensor input = torch::randn({3, 2}, torch::kFloat);2365
2366at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);2367at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);2368
2369testStaticRuntime(2370quantized_linear_relu_dynamic_fp16_script,2371{input, weight},2372{input_2, weight_2});2373}
2374
2375TEST(StaticRuntime, VarStack) {2376const auto var_stack_script = R"JIT(2377def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
2378return torch.stack([inp1, inp2], dim).clone()
2379)JIT";2380
2381// 2D tensors - stack dim = 02382std::vector<IValue> args1 = {at::randn({6, 6}), at::randn({6, 6}), 0};2383testStaticRuntime(var_stack_script, args1);2384
2385// 3D tensors - stack dim = 12386std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1};2387testStaticRuntime(var_stack_script, args2);2388
2389// 3D tensors - stack dim = 22390std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2};2391testStaticRuntime(var_stack_script, args3);2392
2393// Negative dim2394std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), -1};2395testStaticRuntime(var_stack_script, args4);2396
2397// Non-serial path2398std::vector<IValue> args5 = {at::randn({1, 2, 3}), at::randn({1, 2, 3}), 3};2399testStaticRuntime(var_stack_script, args5);2400
2401// Fast path2402std::vector<IValue> args6 = {at::randn({1}), at::randn({1}), 0};2403testStaticRuntime(var_stack_script, args6);2404
2405testStaticRuntime(var_stack_script, args1, args2);2406}
2407
2408TEST(StaticRuntime, FmodTensor) {2409const auto fmod_tensor = R"JIT(2410def forward(self, a: Tensor, b: Tensor):
2411return torch.fmod(a, b).clone()
2412)JIT";2413
2414// fmod tensor version2415auto a = at::randn({2, 3});2416auto b = at::randn({2, 3});2417std::vector<IValue> args0{a, b};2418testStaticRuntime(fmod_tensor, args0);2419
2420// check for dynamic shapes2421auto c = at::randn({4, 3, 2});2422auto d = at::randn({4, 3, 2});2423std::vector<IValue> args1{c, d};2424testStaticRuntime(fmod_tensor, args0, args1);2425}
2426
2427TEST(StaticRuntime, FmodScalar) {2428const auto fmod_scalar = R"JIT(2429def forward(self, a: Tensor, b: int):
2430return torch.fmod(a, b).clone()
2431)JIT";2432
2433auto a = at::randn({2, 3});2434
2435// fmod scalar version2436std::vector<IValue> args2{a, 3};2437testStaticRuntime(fmod_scalar, args2);2438
2439// check for dynamic shapes2440auto c = at::randn({4, 3, 2});2441std::vector<IValue> args3{c, 4};2442testStaticRuntime(fmod_scalar, args2, args3);2443
2444// test int32 version2445a = at::randint(-100, 100, {2, 3}, at::kInt);2446c = at::randint(-100, 100, {4, 3, 2}, at::kInt);2447testStaticRuntime(fmod_scalar, {a, 3});2448testStaticRuntime(fmod_scalar, {a, 3}, {c, 4});2449}
2450
2451TEST(StaticRuntime, QEmbeddingBagBytePrepack) {2452const std::string embedding_bag_byte_prepack_script = R"IR(2453graph(%input: Tensor):
2454%none : None = prim::Constant()
2455%output: Tensor = quantized::embedding_bag_byte_prepack(%input)
2456%res: Tensor = aten::clone(%output, %none)
2457return (%res)
2458)IR";2459
2460auto a = torch::randn({8, 16}, at::ScalarType::Float);2461auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);2462
2463testStaticRuntime(embedding_bag_byte_prepack_script, {a});2464testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b});2465}
2466
2467TEST(StaticRuntime, QEmbeddingBagByteUnpack) {2468const auto src = R"IR(2469graph(%input: Tensor):
2470%none : None = prim::Constant()
2471%weight: Tensor = quantized::embedding_bag_byte_prepack(%input)
2472%output: Tensor = quantized::embedding_bag_byte_unpack(%weight)
2473%res: Tensor = aten::clone(%output, %none)
2474return (%res)
2475)IR";2476
2477auto a = torch::randn({8, 16}, at::ScalarType::Float);2478auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);2479
2480testStaticRuntime(src, {a});2481testStaticRuntime(src, {a}, {b});2482}
2483
2484TEST(StaticRuntime, LinalgNorm_ScalarOrd) {2485const auto linalg_norm_ord_scalar = R"JIT(2486def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int):
2487return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
2488)JIT";2489
2490auto a = at::randn({2, 3});2491auto dim = std::vector<int64_t>({1});2492auto dtype = at::ScalarType::Float;2493
2494std::vector<IValue> args0{a, 4, dim, true, dtype};2495testStaticRuntime(linalg_norm_ord_scalar, args0);2496
2497auto b = at::randn({3, 2, 6});2498std::vector<IValue> args1{b, 4, dim, true, dtype};2499testStaticRuntime(linalg_norm_ord_scalar, args0, args1);2500}
2501
2502TEST(StaticRuntime, LinalgNorm_StringOrd) {2503const auto linalg_norm_ord_str = R"JIT(2504def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
2505return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
2506)JIT";2507
2508auto a = at::randn({2, 3});2509auto dim = std::vector<int64_t>({0, 1});2510auto dtype = at::ScalarType::Float;2511
2512std::vector<IValue> args0{a, "fro", dim, true, dtype};2513testStaticRuntime(linalg_norm_ord_str, args0);2514
2515auto b = at::randn({3, 2, 17});2516std::vector<IValue> args1{b, "fro", dim, true, dtype};2517testStaticRuntime(linalg_norm_ord_str, args0, args1);2518}
2519
2520TEST(StaticRuntime, Index_Put) {2521const auto index_put_str = R"JIT(2522def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
2523return torch.index_put(a, indices, values, accumulate).clone()
2524)JIT";2525
2526auto a = at::randn({2});2527auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong));2528auto values_a = at::randn({1});2529
2530std::vector<IValue> args0{a, indices_a, values_a, false};2531testStaticRuntime(index_put_str, args0);2532
2533const auto index_put_non_optional_str = R"JIT(2534def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool):
2535return torch.index_put(a, indices, values, accumulate).clone()
2536)JIT";2537
2538auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};2539std::vector<IValue> args1{a, indices_b, values_a, false};2540testStaticRuntime(index_put_non_optional_str, args1);2541
2542const auto index_put_list_construct = R"JIT(2543def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool):
2544indices: List[Optional[Tensor]] = [indices]
2545return torch.index_put(a, indices, values, accumulate).clone()
2546)JIT";2547
2548std::vector<IValue> args2{a, torch::tensor({0}, at::kLong), values_a, false};2549testStaticRuntime(index_put_list_construct, args2);2550}
2551
2552TEST(StaticRuntime, Item) {2553const auto item_str = R"JIT(2554def forward(self, a: Tensor):
2555return torch.item(a)
2556)JIT";2557
2558auto a = at::randn({1});2559
2560std::vector<IValue> args0{a};2561testStaticRuntime(item_str, args0);2562}
2563
2564TEST(StaticRuntime, Tensor_Split) {2565const auto tensor_split_str1 = R"JIT(2566def forward(self, a: Tensor, sections: int, dim: int):
2567return torch.tensor_split(a, sections, dim)
2568)JIT";2569std::vector<IValue> args1{at::randn({8}), 3, 0};2570
2571const auto tensor_split_str2 = R"JIT(2572def forward(self, a: Tensor, sections: Tensor, dim: int):
2573return torch.tensor_split(a, sections, dim)
2574)JIT";2575std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};2576
2577const auto tensor_split_str3 = R"JIT(2578def forward(self, a: Tensor, indices: List[int], dim: int):
2579return torch.tensor_split(a, indices, dim)
2580)JIT";2581std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};2582
2583testStaticRuntime(tensor_split_str1, args1);2584testStaticRuntime(tensor_split_str2, args2);2585testStaticRuntime(tensor_split_str3, args3);2586}
2587
2588TEST(StaticRuntime, JIT_Aten_Cpu) {2589const std::string script = R"IR(2590graph(%a: Tensor):
2591%1 : int = prim::Constant[value=0]()
2592%aa: Tensor = aten::add(%a, %a, %1)
2593%ret: Tensor = aten::cpu(%aa)
2594return (%ret)
2595)IR";2596
2597auto graph = std::make_shared<Graph>();2598std::unordered_map<std::string, Value*> vmap;2599vmap.reserve(0);2600parseIR(script, graph.get(), vmap);2601torch::jit::StaticModule smodule(graph);2602
2603auto a = at::randn({2, 4});2604std::vector<IValue> args0{a};2605
2606testStaticRuntime(script, args0);2607}
2608
2609TEST(StaticRuntime, JIT_Aten_Numel) {2610const std::string script = R"IR(2611graph(%a: Tensor):
2612%1 : int = prim::Constant[value=0]()
2613%aa: Tensor = aten::add(%a, %a, %1)
2614%ret: int = aten::numel(%aa)
2615return (%ret)
2616)IR";2617
2618auto graph = std::make_shared<Graph>();2619std::unordered_map<std::string, Value*> vmap;2620vmap.reserve(0);2621parseIR(script, graph.get(), vmap);2622torch::jit::StaticModule smodule(graph);2623
2624auto a = at::randn({2, 4});2625std::vector<IValue> args0{a};2626
2627testStaticRuntime(script, args0);2628}
2629
2630TEST(StaticRuntime, JIT_Aten_List) {2631const auto script_str = R"IR(2632graph(%a: str):
2633%ret: str[] = aten::list(%a)
2634return (%ret)
2635)IR";2636std::string a = "abcd";2637std::vector<IValue> args0{a};2638testStaticRuntime(script_str, args0);2639
2640// Update the result of aten::list to ensure that a deep copy2641// took place2642const auto script_list = R"IR(2643graph(%a : int[]):
2644%idx : int = prim::Constant[value=0]()
2645%value : int = prim::Constant[value=42]()
2646%res : int[] = aten::list(%a)
2647%updated : int[] = aten::_set_item(%res, %idx, %value)
2648return (%res, %a)
2649)IR";2650
2651std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};2652testStaticRuntime(script_list, args1);2653}
2654
2655TEST(StaticRuntime, JIT_Aten_Range_Length) {2656const std::string script = R"IR(2657graph(%lo: int, %hi: int, %step: int):
2658%1 : int = prim::Constant[value=0]()
2659%ret: int = aten::__range_length(%lo, %hi, %step)
2660return (%ret)
2661)IR";2662
2663auto graph = std::make_shared<Graph>();2664std::unordered_map<std::string, Value*> vmap;2665vmap.reserve(0);2666parseIR(script, graph.get(), vmap);2667torch::jit::StaticModule smodule(graph);2668
2669std::vector<IValue> args0{0, 10, 2};2670
2671testStaticRuntime(script, args0);2672}
2673
2674TEST(StaticRuntime, Cat) {2675const std::string cat_script = R"IR(2676graph(%a: Tensor, %b: Tensor, %dim: int):
2677%ten_list: Tensor[] = prim::ListConstruct(%a, %b)
2678%1 : int = prim::Constant[value=0]()
2679%2 : int = prim::Constant[value=1]()
2680%3 : int = prim::Constant[value=1]()
2681%ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3)
2682%ret: Tensor = aten::cat(%ten_list2, %dim)
2683return (%ret)
2684)IR";2685
2686auto graph = std::make_shared<Graph>();2687std::unordered_map<std::string, Value*> vmap;2688parseIR(cat_script, graph.get(), vmap);2689torch::jit::StaticModule smodule(graph);2690ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat"));2691
2692auto a = at::randn({2, 4});2693auto b = at::randn({3, 4});2694std::vector<IValue> args0{a, b, 0};2695
2696testStaticRuntime(cat_script, args0);2697
2698auto c = at::randn({3, 4});2699auto d = at::randn({3, 5});2700std::vector<IValue> args1{c, d, 1};2701testStaticRuntime(cat_script, args0, args1);2702
2703std::vector<IValue> args_dim_negative{c, d, -1};2704testStaticRuntime(cat_script, args_dim_negative);2705}
2706
2707TEST(StaticRuntime, Cumsum) {2708const auto cumsum_script = R"JIT(2709def forward(self, a: Tensor, dim: int):
2710return torch.cumsum(a, dim).clone()
2711)JIT";2712
2713auto a = at::randn({2, 3});2714std::vector<IValue> args0{a, 0};2715testStaticRuntime(cumsum_script, args0);2716
2717auto b = at::randn({3, 6});2718std::vector<IValue> args1{b, 1};2719testStaticRuntime(cumsum_script, args0, args1);2720}
2721
2722TEST(StaticRuntime, CumsumDtype) {2723const auto cumsum_script_dtype = R"JIT(2724def forward(self, a: Tensor, dim: int, dtype: int):
2725return torch.cumsum(a, dim, dtype=dtype).clone()
2726)JIT";2727
2728auto a = at::randn({1, 2});2729auto dtype = at::ScalarType::Float;2730std::vector<IValue> args0{a, 0, dtype};2731testStaticRuntime(cumsum_script_dtype, args0);2732
2733auto b = at::randn({3, 6});2734std::vector<IValue> args1{b, 1, dtype};2735testStaticRuntime(cumsum_script_dtype, args0, args1);2736}
2737
2738TEST(StaticRuntime, Nonzero) {2739const auto nonzero_tensor = R"JIT(2740def forward(self, input: Tensor):
2741a = torch.nonzero(input).clone()
2742return (a)
2743)JIT";2744
2745auto a = at::randint(0, 2, {2, 3});2746testStaticRuntime(nonzero_tensor, {a});2747
2748auto b = at::randint(0, 2, {4, 3, 2});2749testStaticRuntime(nonzero_tensor, {a}, {b});2750}
2751
2752TEST(StaticRuntime, SignedLog1p) {2753const std::string signed_log1p_script = R"IR(2754graph(%input):
2755%0 : Tensor = aten::sign(%input)
2756%1 : Tensor = aten::abs(%input)
2757%2 : Tensor = aten::log1p(%1)
2758%3 : Tensor = aten::mul(%0, %2)
2759%none : NoneType = prim::Constant()
2760%res : Tensor = aten::clone(%3, %none)
2761return (%res)
2762)IR";2763
2764std::vector<IValue> args1 = {at::randn({2, 2})};2765testStaticRuntime(signed_log1p_script, args1, {}, true);2766
2767std::vector<IValue> args2 = {at::randn({3, 3, 3})};2768testStaticRuntime(signed_log1p_script, args1, args2, true);2769}
2770
2771TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithImmutableInputDict) {2772const auto getitem_immutable_input_dict_script = R"JIT(2773def forward(self, input: Dict[int, Tensor]):
2774a = input[0]
2775b = input[1]
2776c = a + b
2777return c.clone()
2778)JIT";2779
2780script::Module module("module");2781module.define(getitem_immutable_input_dict_script);2782torch::jit::StaticModule smodule(module);2783EXPECT_FALSE(hasNodeWithKind(smodule, "aten::__getitem__"));2784EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));2785
2786auto a = at::randn({2, 4});2787auto b = at::randn({2, 4});2788c10::Dict<c10::IValue, c10::IValue> dict(2789c10::IntType::get(), c10::TensorType::get());2790dict.insert(0, a);2791dict.insert(1, b);2792testStaticRuntime(getitem_immutable_input_dict_script, {dict});2793
2794c10::Dict<c10::IValue, c10::IValue> dict0(2795c10::IntType::get(), c10::TensorType::get());2796auto a0 = at::randn({3, 4});2797auto b0 = at::randn({3, 4});2798dict0.insert(0, a0);2799dict0.insert(1, b0);2800testStaticRuntime(getitem_immutable_input_dict_script, {dict0});2801}
2802
2803TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithMutableInputDict) {2804const auto getitem_mutable_input_dict_script = R"JIT(2805def forward(self, input: Dict[int, Tensor]):
2806a = input[0]
2807input[1] = a
2808b = input[1]
2809c = a + b
2810return c.clone()
2811)JIT";2812
2813script::Module module("module");2814module.define(getitem_mutable_input_dict_script);2815torch::jit::StaticModule smodule(module);2816EXPECT_TRUE(hasNodeWithKind(smodule, "aten::__getitem__"));2817EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));2818}
2819
2820TEST(StaticRuntime, VarTupleUnpack) {2821const auto var_tuple_unpack_script = R"JIT(2822def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
2823a, b = input_0
2824c, d = input_1
2825res = a * c + b * d
2826return res.clone()
2827)JIT";2828
2829script::Module module("module");2830module.define(var_tuple_unpack_script);2831torch::jit::StaticModule smodule(module);2832EXPECT_FALSE(hasNodeWithKind(smodule, "prim::TupleUnpack"));2833EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));2834
2835auto a = at::randn({2, 2});2836auto b = at::randn({3, 3, 3});2837std::vector<IValue> args1{2838c10::ivalue::Tuple::create(a, a), c10::ivalue::Tuple::create(1, 2)};2839std::vector<IValue> args2{2840c10::ivalue::Tuple::create(b, b), c10::ivalue::Tuple::create(1, 2)};2841
2842testStaticRuntime(var_tuple_unpack_script, args1);2843testStaticRuntime(var_tuple_unpack_script, args1, args2);2844}
2845
2846TEST(StaticRuntime, VarTupleUnpack_NotApplied) {2847const auto var_tuple_unpack_not_applied_script = R"JIT(2848def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
2849a, b = input_0
2850x = a + b
2851c, d = input_1
2852res = a * c + b * d + x
2853return res.clone()
2854)JIT";2855
2856script::Module module("module");2857// In this script, the optimization is not applied since there is a2858// computation between the TupleUnpack nodes.2859module.define(var_tuple_unpack_not_applied_script);2860torch::jit::StaticModule smodule(module);2861EXPECT_FALSE(hasNodeWithKind(smodule, "static_runtime::VarTupleUnpack"));2862EXPECT_TRUE(hasNodeWithKind(smodule, "prim::TupleUnpack"));2863}
2864
2865TEST(StaticRuntime, RemainderTensor) {2866const auto remainder_tensor = R"JIT(2867def forward(self, x, y):
2868return torch.remainder(x, y).clone()
2869)JIT";2870
2871std::vector<IValue> args1 = {2872at::randint(0, 10, {2, 2}), at::randint(1, 10, {2, 2})};2873std::vector<IValue> args2 = {2874at::randint(0, 10, {3, 6}), at::randint(1, 10, {3, 6})};2875
2876// Use allclose and equalnan since outputs may be NaN.2877testStaticRuntime(2878remainder_tensor,2879args1,2880/*args2*/ {},2881/*use_alloclose*/ true,2882/*use_equalnan*/ true);2883testStaticRuntime(2884remainder_tensor,2885args1,2886args2,2887/*use_allclose*/ true,2888/*use_equalnan*/ true);2889}
2890
2891TEST(StaticRuntime, RemainderScalar) {2892const auto remainder_scalar = R"JIT(2893def forward(self, x, y: int):
2894return torch.remainder(x, y).clone()
2895)JIT";2896
2897std::vector<IValue> args1 = {at::randint(0, 10, {2, 2}), 4};2898std::vector<IValue> args2 = {at::randint(0, 10, {3, 6}), 4};2899
2900// Use allclose and equalnan since outputs may be NaN.2901testStaticRuntime(2902remainder_scalar,2903args1,2904/*args2*/ {},2905/*use_alloclose*/ true,2906/*use_equalnan*/ true);2907testStaticRuntime(2908remainder_scalar,2909args1,2910args2,2911/*use_allclose*/ true,2912/*use_equalnan*/ true);2913}
2914
2915TEST(StaticRuntime, Where) {2916const auto where_script = R"JIT(2917def forward(self, x, y):
2918return torch.where(x > 0, x, y).clone()
2919)JIT";2920
2921std::vector<IValue> args1 = {at::randn({2, 2}), at::randn({2, 2})};2922std::vector<IValue> args2 = {at::randn({8, 10}), at::randn({8, 10})};2923
2924testStaticRuntime(where_script, args1);2925testStaticRuntime(where_script, args1, args2);2926}
2927
2928TEST(StaticRuntime, WhereBroadcast) {2929const auto where_script = R"JIT(2930def forward(self, cond_1d, x, y):
2931shape = [-1] + [1] * (x.dim() - 1)
2932cond = cond_1d.view(shape)
2933return torch.where(cond, x, y).clone()
2934)JIT";2935
2936std::vector<IValue> args1 = {2937at::tensor({0, 1}).to(at::kBool), at::randn({2, 2}), at::randn({2, 2})};2938std::vector<IValue> args2 = {2939at::tensor({1, 0, 0}).to(at::kBool),2940at::randn({3, 6}),2941at::randn({3, 6})};2942
2943testStaticRuntime(where_script, args1);2944testStaticRuntime(where_script, args1, args2);2945}
2946
2947TEST(StaticRuntime, View) {2948// Note that clone is not technically necessary here since this is not2949// an out variant, but it suppresses warnings about only have one op2950// in testStaticRuntime2951const auto src = R"IR(2952graph(%input : Tensor, %shape : int[]):
2953%none : NoneType = prim::Constant()
2954%view : Tensor = aten::view(%input, %shape)
2955%res : Tensor = aten::clone(%view, %none)
2956return (%res)
2957)IR";2958
2959std::vector<IValue> args1{at::randn({2, 2}), c10::List<int64_t>(4)};2960std::vector<IValue> args2{at::randn({2, 2, 2}), c10::List<int64_t>({4, 2})};2961
2962testStaticRuntime(src, args1);2963testStaticRuntime(src, args1, args2);2964}
2965
2966TEST(StaticRuntime, Size) {2967const auto src_with_dim = R"JIT(2968def forward(self, x, dim: int):
2969return x.size(dim)
2970)JIT";2971
2972const auto src_no_dim = R"JIT(2973def forward(self, x):
2974return x.size()
2975)JIT";2976
2977std::vector<IValue> args1{at::randn({1}), 0};2978std::vector<IValue> args2{at::randn({1}), -1};2979std::vector<IValue> args3{at::randn({2, 4}), 1};2980std::vector<IValue> args_no_dim{at::randn({2, 4})};2981
2982testStaticRuntime(src_with_dim, args1);2983testStaticRuntime(src_with_dim, args2);2984testStaticRuntime(src_with_dim, args1, args3);2985testStaticRuntime(src_no_dim, args_no_dim);2986}
2987
2988TEST(StaticRuntime, Squeeze) {2989// Note: this is a native op, not an out variant, but clone anyways2990// to silence warnings in testStaticRuntime2991const auto src = R"JIT(2992def forward(self, inp, dim: int):
2993return inp.squeeze(dim).clone()
2994)JIT";2995
2996const auto a = at::randn({2, 2});2997const auto b = at::randn({3, 2, 3});2998
2999testStaticRuntime(src, {a, 0});3000testStaticRuntime(src, {a, 1});3001testStaticRuntime(src, {a, -1}, {b, 2});3002}
3003
3004TEST(StaticRuntime, NumToTensorScalar) {3005const auto num_to_tensor_ir = R"IR(3006graph(%1 : int):
3007%2 : NoneType = prim::Constant()
3008%3 : Tensor = prim::NumToTensor(%1)
3009%4 : Tensor = aten::clone(%3, %2)
3010return (%4)
3011)IR";3012
3013IValue arg{5};3014std::vector<IValue> args = {arg};3015testStaticRuntime(num_to_tensor_ir, args);3016}
3017
3018TEST(StaticRuntime, NumToTensorFalse) {3019const auto num_to_tensor_ir = R"IR(3020graph(%1 : bool):
3021%2 : NoneType = prim::Constant()
3022%3 : Tensor = prim::NumToTensor(%1)
3023%4 : Tensor = aten::clone(%3, %2)
3024return (%4)
3025)IR";3026
3027IValue arg{false};3028std::vector<IValue> args = {arg};3029testStaticRuntime(num_to_tensor_ir, args);3030}
3031
3032TEST(StaticRuntime, NumToTensorTrue) {3033const auto num_to_tensor_ir = R"IR(3034graph(%1 : bool):
3035%2 : NoneType = prim::Constant()
3036%3 : Tensor = prim::NumToTensor(%1)
3037%4 : Tensor = aten::clone(%3, %2)
3038return (%4)
3039)IR";3040
3041IValue arg{true};3042std::vector<IValue> args = {arg};3043testStaticRuntime(num_to_tensor_ir, args);3044}
3045
3046TEST(StaticRuntime, Split) {3047const auto src = R"JIT(3048def forward(self, inp, split_size: int, dim: int):
3049return inp.split(split_size, dim)
3050)JIT";3051
3052const auto a = at::randn({2, 2});3053const auto b = at::randn({2, 2, 2});3054
3055testStaticRuntime(src, {a, 1, 0});3056testStaticRuntime(src, {a, 1, 1});3057testStaticRuntime(src, {a, 2, -1}, {b, 2, 2});3058}
3059
3060TEST(StaticRuntime, SplitWithSizes) {3061const auto src = R"JIT(3062def forward(self, inp, split_sizes: List[int], dim: int):
3063return inp.split(split_sizes, dim)
3064)JIT";3065
3066const auto a = at::randn({2, 2});3067const auto b = at::randn({2, 2, 2});3068const auto split_sizes = c10::List<int64_t>{1, 1};3069
3070testStaticRuntime(src, {a, split_sizes, 0});3071testStaticRuntime(src, {a, split_sizes, 1});3072testStaticRuntime(src, {a, split_sizes, -1}, {b, split_sizes, 2});3073}
3074
3075namespace {3076
3077void maybe_throw(bool should_throw) {3078if (should_throw) {3079throw std::runtime_error("test exception");3080}3081}
3082
3083TORCH_LIBRARY(static_runtime_tests, m) {3084// Conservative so this op doesn't get deleted by dead3085// code elimination3086m.def(torch::schema(3087"static_runtime_tests::maybe_throw(bool throw) -> ()",3088at::AliasAnalysisKind::CONSERVATIVE));3089m.impl("maybe_throw", maybe_throw);3090}
3091
3092} // namespace3093
3094TEST(StaticRuntime, ModelCrashOnFirstRun) {3095const auto src = R"JIT(3096graph(%0: Tensor, %throw: bool):
3097%1: Tensor = aten::mul(%0, %0)
3098static_runtime_tests::maybe_throw(%throw)
3099%2: Tensor = aten::mul(%1, %1)
3100%3: Tensor = aten::mul(%2, %2)
3101return (%3)
3102)JIT";3103
3104auto graph = getGraphFromIR(src);3105auto static_module = StaticModule(graph);3106auto& runtime = static_module.runtime();3107
3108std::vector<IValue> args_crash{at::randn({1}), true};3109std::vector<IValue> args_no_crash{at::randn({1}), false};3110EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);3111
3112// The run didn't finish, we didn't allocate the memory planner3113EXPECT_EQ(runtime.get_memory_planner(), nullptr);3114runtime.check_for_memory_leak();3115
3116// We guarantee that the runtime is still usable after the crash.3117// Run again to verify this.3118compareResultsWithJIT(runtime, graph, args_no_crash);3119EXPECT_NE(runtime.get_memory_planner(), nullptr);3120}
3121
3122TEST(StaticRuntime, ModelCrashOnSecondRun) {3123const auto src = R"JIT(3124graph(%0: Tensor, %throw: bool):
3125%1: Tensor = aten::mul(%0, %0)
3126static_runtime_tests::maybe_throw(%throw)
3127%2: Tensor = aten::mul(%1, %1)
3128%3: Tensor = aten::mul(%2, %2)
3129return (%3)
3130)JIT";3131
3132auto graph = getGraphFromIR(src);3133auto static_module = StaticModule(graph);3134auto& runtime = static_module.runtime();3135
3136std::vector<IValue> args_crash{at::randn({1}), true};3137std::vector<IValue> args_no_crash{at::randn({1}), false};3138runtime(args_no_crash, {});3139EXPECT_NE(runtime.get_memory_planner(), nullptr);3140runtime.check_for_memory_leak();3141
3142EXPECT_THROW(runtime(args_crash, {}), std::runtime_error);3143runtime.check_for_memory_leak();3144
3145// We guarantee that the runtime is still usable after the crash.3146// Run again to verify this.3147compareResultsWithJIT(runtime, graph, args_no_crash);3148}
3149
3150TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {3151const auto src = R"JIT(3152graph(%0: Tensor):
3153%1: Tensor = aten::mul(%0, %0)
3154%2: Tensor = aten::mul(%1, %1)
3155%3: bool = prim::Constant[value=1]()
3156%4: Tensor = static_runtime::select_tensor(%1, %2, %3)
3157static_runtime_tests::maybe_throw(%3)
3158return (%4)
3159)JIT";3160auto graph = getGraphFromIR(src);3161auto static_module = StaticModule(graph);3162auto& runtime = static_module.runtime();3163
3164std::vector<IValue> args{at::randn({1})};3165EXPECT_THROW(runtime(args), std::runtime_error);3166}
3167
3168TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {3169const auto src = R"JIT(3170graph(%0: Tensor, %1: Tensor):
3171%2: bool = prim::Constant[value=1]()
3172%3: Tensor = static_runtime::select_tensor(%0, %1, %2)
3173static_runtime_tests::maybe_throw(%2)
3174return (%3)
3175)JIT";3176auto graph = getGraphFromIR(src);3177auto static_module = StaticModule(graph);3178auto& runtime = static_module.runtime();3179
3180std::vector<IValue> args{at::randn({1}), at::randn({1})};3181EXPECT_THROW(runtime(std::move(args)), std::runtime_error);3182}
3183
3184TEST(StaticRuntime, ReplaceWithMaybeCopy) {3185const std::string to = R"IR(3186graph(%0 : Tensor):
3187%1: int = prim::Constant[value=4]()
3188%2: bool = prim::Constant[value=0]()
3189%3: None = prim::Constant()
3190%res : Tensor = aten::to(%0, %1, %2, %2, %3)
3191return (%res)
3192)IR";3193
3194at::Tensor a = at::tensor({1.1, 2.2, 3.3, 4.0}, at::ScalarType::Float);3195std::vector<IValue> args{a};3196auto g = std::make_shared<torch::jit::Graph>();3197torch::jit::parseIR(to, g.get());3198
3199// Jit Interpreter.3200Stack stack(args);3201torch::jit::GraphExecutor graph_exec(g, "");3202graph_exec.run(stack);3203ASSERT_EQ(stack.size(), 1);3204auto expected = stack[0].toTensor();3205
3206// Static Runtime.3207torch::jit::StaticModule smodule(g);3208auto actual = smodule(args, {}).toTensor();3209smodule.runtime().check_for_memory_leak();3210
3211EXPECT_TRUE(expected.equal(actual));3212
3213// Make a fresh graph to ensure the pass works in isolation3214auto new_graph = std::make_shared<torch::jit::Graph>();3215torch::jit::parseIR(to, new_graph.get());3216ReplaceWithMaybeCopy(new_graph);3217EXPECT_FALSE(hasNodeWithKind(new_graph, "aten::to"));3218EXPECT_TRUE(3219hasNodeWithKind(new_graph, "static_runtime::to_maybe_copy_out"));3220}
3221
3222TEST(StaticRuntime, Int) {3223const auto src = R"JIT(3224def forward(self, x):
3225return int(x) + int(x)
3226)JIT";3227std::vector<IValue> args{at::tensor({3.14})};3228testStaticRuntime(src, args);3229}
3230
3231TEST(StaticRuntime, ReturnConstant) {3232const auto src = R"JIT(3233def forward(self):
3234return 1
3235)JIT";3236
3237testStaticRuntime(src, {});3238}
3239
3240TEST(StaticRuntime, SimpleIf) {3241const auto src = R"JIT(3242def forward(self, cond: bool, x):
3243if cond:
3244return torch.mul(x, 42).clone()
3245else:
3246return x.clone()
3247)JIT";3248
3249std::vector<IValue> args_false{false, at::randn({1})};3250std::vector<IValue> args_true{true, at::randn({1})};3251std::vector<IValue> args_big_tensor{true, at::randn({3, 3, 3})};3252
3253testStaticRuntime(src, args_false);3254testStaticRuntime(src, args_true);3255testStaticRuntime(src, args_true, args_big_tensor);3256}
3257
3258TEST(StaticRuntime, NestedIf) {3259const auto src = R"JIT(3260def forward(self, cond1: bool, cond2: bool, x):
3261y = x * 42
3262if cond1:
3263y = y * y
3264if cond2:
3265y += x
3266else:
3267if cond2:
3268return x.clone()
3269
3270return y.clone()
3271)JIT";3272
3273for (auto cond1 : {true, false}) {3274for (auto cond2 : {true, false}) {3275std::vector<IValue> args1{cond1, cond2, at::randn({1})};3276std::vector<IValue> args2{cond1, cond2, at::randn({3, 3, 3})};3277testStaticRuntime(src, args1, args2);3278}3279}3280}
3281
TEST(StaticRuntime, DeeplyNestedIf) {
  // Stress test for deeply nested prim::If blocks, including an early return
  // from the else-branch and sibling ifs at several nesting depths.
  const auto src = R"JIT(
    def forward(self, cond1: bool, cond2: bool, cond3: bool, x):
        y = x * 42
        if cond1:
            y = y * y
            if cond2:
                y += x

            if cond2 and cond3:
                y += 1

            if cond2:
                if cond3:
                    y += 2
                else:
                    y = y * y
                    y += 4
        else:
            if cond2:
                return x.clone()
            if cond3 or cond2:
                y += 42

        return y.clone()
  )JIT";

  // Exhaustively cover all 8 condition combinations; the second, larger
  // input forces the memory planner to resize managed tensors.
  for (auto cond1 : {true, false}) {
    for (auto cond2 : {true, false}) {
      for (auto cond3 : {true, false}) {
        std::vector<IValue> args1{cond1, cond2, cond3, at::randn({1})};
        std::vector<IValue> args2{cond1, cond2, cond3, at::randn({3, 3, 3})};
        testStaticRuntime(src, args1, args2);
      }
    }
  }
}
3319
3320TEST(StaticRuntime, BasicForLoop) {3321const auto src = R"JIT(3322def forward(self, x, loop_max: int):
3323y = x.clone()
3324for i in range(loop_max):
3325y += 1
3326return y
3327)JIT";3328
3329std::vector<IValue> args1{at::randn({1}), 10};3330std::vector<IValue> args2{at::randn({3, 3, 3}), 10};3331
3332testStaticRuntime(src, args1, args2);3333}
3334
3335TEST(StaticRuntime, BasicWhileLoop) {3336const auto src = R"JIT(3337def forward(self, x, loop_max: int):
3338y = x.clone()
3339loop_count = 0
3340while loop_count < loop_max:
3341y += 1
3342loop_count += 1
3343return y
3344)JIT";3345
3346std::vector<IValue> args1{at::randn({1}), 10};3347std::vector<IValue> args2{at::randn({3, 3, 3}), 10};3348
3349testStaticRuntime(src, args1, args2);3350}
3351
TEST(StaticRuntime, NestedLoops) {
  // Loop containing an if/else plus an inner loop; also returns list outputs
  // (even/odd) alongside a tensor to exercise multi-output handling.
  const auto src = R"JIT(
    def forward(self, x, loop_max: int):
        y = x.clone()
        even: List[int] = []
        odd: List[int] = []

        for i in range(loop_max):
            if i % 2:
                odd.append(i)
            else:
                even.append(i)

            for j in range(i):
                y += 1

        return y, even, odd
  )JIT";

  // Second arg set uses a larger tensor to trigger storage resizing.
  std::vector<IValue> args1{at::randn({1}), 10};
  std::vector<IValue> args2{at::randn({3, 3, 3}), 10};

  testStaticRuntime(src, args1, args2);
}
3376
3377TEST(StaticRuntime, TupleIndex) {3378const auto src = R"JIT(3379def forward(self, idx: int, tup: Tuple[int, int]):
3380a = tup[idx]
3381return a * a
3382)JIT";3383const auto tuple = c10::ivalue::Tuple::create({1, 2});3384testStaticRuntime(src, {1, tuple}, {-1, tuple});3385
3386torch::jit::Module mod("module");3387mod.define(src);3388StaticModule smod(mod);3389EXPECT_THROW(smod({100, tuple}), std::out_of_range);3390}
3391
3392TEST(StaticRuntime, RaiseException) {3393const auto src = R"IR(3394graph(%str: str):
3395%none: NoneType = prim::Constant()
3396prim::RaiseException(%str, %none)
3397return (%none)
3398)IR";3399auto graph = getGraphFromIR(src);3400StaticModule smod(graph);3401const auto msg = "exception message";3402EXPECT_THROW(3403{3404try {3405smod({msg});3406} catch (const std::runtime_error& e) {3407EXPECT_STREQ(msg, e.what());3408throw;3409}3410},3411std::runtime_error);3412}
3413
3414TEST(StaticRuntime, Uninitialized) {3415const auto src = R"IR(3416graph():
3417%0: int = prim::Uninitialized()
3418return (%0)
3419)IR";3420auto graph = getGraphFromIR(src);3421StaticModule smod(graph);3422const auto ret = smod({});3423// If a and b are both uninitialized, then a != b. So just check that the type3424// is Any3425EXPECT_EQ(ret.type()->kind(), c10::TypeKind::AnyType);3426}
3427
3428TEST(StaticRuntime, Format) {3429const auto src = R"JIT(3430def forward(self, arg1: int, arg2: Tensor, arg3: str):
3431a = "arg1: {}, arg2: {}, arg3: {}".format(arg1, arg2, arg3)
3432return a[::]
3433)JIT";3434testStaticRuntime(src, {1, at::randn({3}), "str"});3435}
3436
3437TEST(StaticRuntime, Device) {3438const auto src = R"JIT(3439def forward(self, x):
3440return x.device, x.device
3441)JIT";3442testStaticRuntime(src, {at::tensor({1})});3443}
3444
3445TEST(StaticRuntime, Dtype) {3446const auto src = R"JIT(3447def forward(self, x, y):
3448return x.dtype, y.dtype
3449)JIT";3450testStaticRuntime(3451src, {at::tensor({1}, at::kLong), at::tensor({1}, at::kFloat)});3452}
3453
3454TEST(StaticRuntime, Dim) {3455const auto src = R"JIT(3456def forward(self, x, y):
3457return x.dim(), y.dim()
3458)JIT";3459testStaticRuntime(src, {at::randn({2, 2}), at::randn({1})});3460}
3461
3462TEST(StaticRuntime, Not) {3463const auto src = R"JIT(3464def forward(self, x: bool, y: bool):
3465return not x, not y
3466)JIT";3467testStaticRuntime(src, {true, false});3468}
3469
3470TEST(StaticRuntime, Bool) {3471const auto src = R"JIT(3472def forward(self, x: Tensor, y: int, z: float):
3473return bool(x), bool(y), bool(z)
3474)JIT";3475testStaticRuntime(src, {at::randn({1}), 0, 1.151}, {at::zeros({1}), 1, 0.0});3476}
3477
3478TEST(StaticRuntime, IsCuda) {3479const auto src = R"JIT(3480def forward(self, x: Tensor, y: Tensor):
3481return x.is_cuda, y.is_cuda
3482)JIT";3483testStaticRuntime(src, {at::randn({1}), at::randn({1})});3484}
3485
3486TEST(StaticRuntime, ToList) {3487const auto src = R"JIT(3488graph(%x: Tensor):
3489%type: int = prim::Constant[value=1]()
3490%dim: int = aten::dim(%x)
3491%ret: float[] = prim::tolist(%x, %dim, %type)
3492return (%ret)
3493)JIT";3494testStaticRuntime(src, {at::randn({2, 2})});3495}
3496
3497TEST(StaticRuntime, IfThenElse) {3498const auto src = R"IR(3499graph(%cond: bool, %a: Tensor, %b: Tensor):
3500%none: NoneType = prim::Constant()
3501%c: Tensor = prim::IfThenElse(%cond, %a, %b)
3502%d: Tensor = aten::clone(%c, %none)
3503return (%d)
3504)IR";3505
3506std::vector<IValue> args1{true, at::randn({1}), at::randn({1})};3507std::vector<IValue> args2{false, at::randn({1}), at::randn({1})};3508
3509testStaticRuntime(src, args1);3510testStaticRuntime(src, args2);3511}
3512
3513TEST(StaticRuntime, EmptyIfBlock) {3514const auto src =3515R"JIT(3516def forward(self, cond: bool, a: Tensor, b: Tensor):
3517l = []
3518if cond:
3519l.append((a + b).clone())
3520return l
3521)JIT";3522
3523testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});3524testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});3525}
3526
3527TEST(StaticRuntime, EmptyNestedIfBlock) {3528const auto src =3529R"JIT(3530def forward(self, cond: bool, a: Tensor, b: Tensor):
3531l = []
3532if cond:
3533if cond:
3534l.append((a + b).clone())
3535return l
3536)JIT";3537
3538testStaticRuntime(src, {true, at::rand(1), at::rand({1, 2})});3539testStaticRuntime(src, {false, at::rand(1), at::rand({1, 2})});3540}
3541
3542TEST(StaticRuntime, StackEmpty) {3543const auto src = R"JIT(3544def forward(self):
3545x = torch.stack([])
3546return x
3547)JIT";3548
3549torch::jit::Module mod("mod");3550mod.define(src);3551
3552torch::jit::StaticModule smod(mod);3553EXPECT_THROW(smod({}), c10::Error);3554}
3555
3556TEST(StaticRuntime, ConcatEmpty) {3557const auto src = R"JIT(3558def forward(self):
3559x = torch.concat([])
3560return x
3561)JIT";3562
3563torch::jit::Module mod("mod");3564mod.define(src);3565
3566torch::jit::StaticModule smod(mod);3567EXPECT_THROW(smod({}), c10::Error);3568}
3569
3570TEST(StaticRuntime, IntImplicit) {3571const auto src = R"IR(3572graph(%a: Tensor):
3573%y: int = aten::IntImplicit(%a)
3574return (%y)
3575)IR";3576testStaticRuntime(src, {at::tensor({1}, at::kInt).squeeze()});3577}
3578
3579TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) {3580const auto src = R"IR(3581graph(%a: Tensor):
3582%y: int = aten::IntImplicit(%a)
3583return (%y)
3584)IR";3585auto graph = getGraphFromIR(src);3586torch::jit::StaticModule smod(graph);3587// Not 0D tensor3588EXPECT_THROW(smod({at::tensor({1, 2}, at::kInt)}), std::runtime_error);3589// Wrong dtype3590EXPECT_THROW(3591smod({at::tensor({1}, at::kFloat).squeeze()}), std::runtime_error);3592}
3593
3594TEST(StaticRuntime, Select) {3595const auto src = R"IR(3596graph(%a: Tensor, %dim: int, %index: int):
3597%none: NoneType = prim::Constant()
3598%b: Tensor = aten::select(%a, %dim, %index)
3599%c: Tensor = aten::clone(%b, %none)
3600return (%c)
3601)IR";3602testStaticRuntime(src, {at::randn({2, 2}), 0, 1});3603}
3604
3605TEST(StaticRuntime, ReshapeAs) {3606const auto src = R"JIT(3607def forward(self, a, b):
3608return a.reshape_as(b).clone()
3609)JIT";3610testStaticRuntime(src, {at::randn({2, 2}), at::randn({4})});3611}
3612
3613TEST(StaticRuntime, MoveCtor) {3614auto mod = getDeepAndWideSciptModel();3615std::vector<IValue> args{3616at::randn({1, 1, 32}), at::randn({1, 1, 32}), at::randn({1, 50})};3617
3618torch::jit::StaticModule smod(mod);3619
3620torch::jit::StaticRuntime runtime(smod);3621auto expected = runtime(args);3622
3623torch::jit::StaticRuntime new_runtime(std::move(runtime));3624auto actual = new_runtime(args);3625compareResults(expected, actual);3626}
3627
3628TEST(StaticRuntime, SingleBlockIfReturnList) {3629const auto src = R"JIT(3630def forward(self, a, b, cond: bool):
3631lst = []
3632if cond:
3633lst.append(a + b)
3634return lst
3635)JIT";3636std::vector<IValue> args1{at::randn({1}), at::randn({1}), true};3637std::vector<IValue> args2{at::randn({42, 42}), at::randn({42, 42}), false};3638testStaticRuntime(src, args1, args2);3639}
3640
3641TEST(StaticRuntime, NestedBlockIfReturnList) {3642const auto src = R"JIT(3643def forward(self, a, b, cond1: bool, cond2: bool):
3644if cond1:
3645lst = []
3646if cond2:
3647lst.append(a + b)
3648lst.append(a * b)
3649return lst
3650return []
3651)JIT";3652std::vector<IValue> args1{at::randn({1}), at::randn({1}), true, true};3653std::vector<IValue> args2{3654at::randn({42, 42}), at::randn({42, 42}), true, false};3655testStaticRuntime(src, args1, args2);3656}
3657
TEST(StaticRuntime, ClampNaNToNum) {
  // clamp followed by nan_to_num: src1 uses default replacements, src2 passes
  // a runtime nan replacement, src3 uses an inverted (min > max) clamp range.
  const auto src1 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone()
  )JIT";

  const auto src2 = R"JIT(
    def forward(self, a, nan: float):
        return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone()
  )JIT";

  const auto src3 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone()
  )JIT";

  // Input deliberately covers NaN, +/-inf and finite values.
  auto a = at::tensor({
      std::numeric_limits<float>::quiet_NaN(),
      std::numeric_limits<float>::infinity(),
      -std::numeric_limits<float>::infinity(),
      0.0f,
      3.0f
  });
  auto b = a.repeat({10, 5});

  // Have to use_allclose even though all NaNs will be replaced - testStaticRuntime
  // also checks inputs at the end to make sure they're not changed
  testStaticRuntime(src1, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src1, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);

  testStaticRuntime(src2, {a, 42.0}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src2, {a, 2.0}, {b, 1.0}, /*use_allclose=*/true, /*use_equalnan=*/true);

  testStaticRuntime(src3, {a}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src3, {a}, {b}, /*use_allclose=*/true, /*use_equalnan=*/true);

  // Non-NNC path
  testStaticRuntime(src1, {a.to(at::kDouble)}, {}, /*use_allclose=*/true, /*use_equalnan=*/true);
  testStaticRuntime(src1, {a.to(at::kDouble)}, {b.to(at::kDouble)}, /*use_allclose=*/true, /*use_equalnan=*/true);
}
3698
3699TEST(StaticRuntime, IfReturningTuple) {3700const auto src = R"JIT(3701def forward(self, x, y, cond: bool, idx: int):
3702if cond:
3703tup = (x, y)
3704else:
3705tup = (x, x)
3706return tup[idx]
3707)JIT";3708
3709std::vector<IValue> args{at::randn({3}), at::randn({3}), true, 0};3710testStaticRuntime(src, args);3711}
3712