pytorch

Форк
0
/
test_static_runtime.cc 
3711 строк · 111.9 Кб
1
#include <ATen/core/dispatch/OperatorOptions.h>
2
#include <c10/core/ScalarType.h>
3
#include <gtest/gtest.h>
4
#include <torch/csrc/jit/ir/alias_analysis.h>
5
#include <torch/csrc/jit/ir/irparser.h>
6
#include <torch/csrc/jit/runtime/static/ProcessedNodeInputs.h>
7
#include <torch/csrc/jit/runtime/static/impl.h>
8
#include <torch/csrc/jit/runtime/static/passes.h>
9
#include <torch/csrc/jit/testing/file_check.h>
10
#include <stdexcept>
11

12
#include "deep_wide_pt.h"
13
#include "test_utils.h"
14

15
using namespace caffe2;
16
using namespace torch;
17
using namespace torch::jit;
18
using namespace torch::jit::test;
19
using c10::IValue;
20

21
/*
22
 When adding a test for an operator implemented in static runtime, there are
23
 several things that you need to pay attention to:
24

25
 1) if the op is an out variant, in the test script of the op,
26
 instead of:
27
    def forward(self, input):
28
      return myop(input)
29

30
  do:
31
    def forward(self, input):
32
      return myop(input).clone()
33

34
 This makes sure that the output of myop is managed by the memory planner and
35
 exercises the code path in the op impl that otherwise doesn't get exercised. The
36
 output of the model is not managed by the memory planner, because it needs to
37
 be returned to the client.
38

39
 2) The memory planner rounds up the size of each Tensor's storage to multiples
40
 of 64 bytes (alignment requirement on AVX512). Make sure the sizes of the input
41
 tensors in args2 are big enough to trigger resizing.
42

43
 3) for view ops such as aten::reshape or aten::to, if you want it to be
44
 replaced by the copy version with the ReplaceWithCopy pass in passes.h, you
45
 also want to make sure its output is not returned as the model output. The
46
 reason is that ReplaceWithCopy only replaces the op whose output is not an
47
 alias of the model output.
48
*/
49

50
// Declares (does not define) the static runtime fast-math flag so tests in
// this file can read and modify it; the Sigmoid test below temporarily sets
// it to false.
C10_DECLARE_bool(static_runtime_enable_fast_math);
51

52
TEST(StaticRuntime, UnaryOps) {
  // torch.sum variants: full reduction, reduction along dim 0 / dim 1, and the
  // same with keepdim=True. Every result is cloned so the sum output is an
  // intermediate value rather than the model output.
  const auto aten_sum = R"JIT(
    def forward(self, input):
        return torch.sum(input).clone()
  )JIT";

  const auto aten_sum_0 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0).clone()
  )JIT";

  const auto aten_sum_1 = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1).clone()
  )JIT";

  const auto aten_sum_0_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 0, True).clone()
  )JIT";

  const auto aten_sum_1_true = R"JIT(
    def forward(self, input):
        return torch.sum(input, 1, True).clone()
  )JIT";

  auto base_input = at::randn({2, 3});
  auto big_input = at::randn({3, 3, 6});

  const std::vector<IValue> base_args{base_input};
  const std::vector<IValue> big_args{big_input};

  // The dim-wise variants all share the same call pattern below.
  const char* const dim_scripts[] = {
      aten_sum_0, aten_sum_1, aten_sum_0_true, aten_sum_1_true};

  // Single-input runs.
  testStaticRuntime(aten_sum, base_args);
  for (const char* script : dim_scripts) {
    testStaticRuntime(script, base_args);
  }

  // Second run with a larger input. The full reduction passes three trailing
  // `false` flags (the 4th argument is use_allclose; NOTE(review): the other
  // two are presumably use_equalnan / check_resize -- confirm in test_utils.h).
  testStaticRuntime(aten_sum, base_args, big_args, false, false, false);
  for (const char* script : dim_scripts) {
    testStaticRuntime(script, base_args, big_args);
  }
}
96

97
TEST(StaticRuntime, Max) {
  // torch.max overloads: full reduction, reduction along `dim` returning
  // (values, indices) with and without keepdim, and elementwise max of two
  // tensors.
  const auto src_max_reduce = R"JIT(
    def forward(self, input):
        return torch.max(input).clone()
  )JIT";

  const auto src_max_dim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim)
        return values.clone(), indices.clone()
  )JIT";

  const auto src_max_dim_keepdim = R"JIT(
    def forward(self, input, dim: int):
        values, indices = torch.max(input, dim, keepdim=True)
        return values.clone(), indices.clone()
  )JIT";

  const auto src_max_pointwise = R"JIT(
    def forward(self, input, other):
        return torch.max(input, other).clone()
  )JIT";

  auto base = at::randn({2, 3, 2});
  auto base_other = at::randn({2, 3, 2});
  auto big = at::randn({8, 9, 10});
  auto big_other = at::randn({8, 9, 10});

  testStaticRuntime(src_max_reduce, {base});

  // Reduction along a dim; the second-run inputs use a different shape and dim.
  testStaticRuntime(src_max_dim, {base, 1});
  testStaticRuntime(src_max_dim, {base, 1}, {big, 0});
  testStaticRuntime(src_max_dim_keepdim, {base, 0});
  testStaticRuntime(src_max_dim_keepdim, {base, 0}, {big, 2});

  // Elementwise max of two same-shaped tensors.
  const std::vector<IValue> pointwise_base{base, base_other};
  const std::vector<IValue> pointwise_big{big, big_other};
  testStaticRuntime(src_max_pointwise, pointwise_base);
  testStaticRuntime(src_max_pointwise, pointwise_base, pointwise_big);
}
133

134
TEST(StaticRuntime, Mean) {
  // torch.mean overloads: full reduction (optionally with dtype) and reduction
  // over a list of dims (optionally with keepdim or dtype).
  const auto src_default = R"JIT(
    def forward(self, input):
        return torch.mean(input).clone()
  )JIT";
  const auto src_dtype = R"JIT(
    def forward(self, input, dtype: int):
        return torch.mean(input, dtype=dtype).clone()
  )JIT";
  const auto src_dim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim).clone()
  )JIT";
  const auto src_dim_keepdim = R"JIT(
    def forward(self, input, dim: List[int]):
        return torch.mean(input, dim, keepdim=True).clone()
  )JIT";
  const auto src_dim_dtype = R"JIT(
    def forward(self, input, dim: List[int], dtype: int):
        return torch.mean(input, dim, dtype=dtype).clone()
  )JIT";

  auto base_input = at::randn({2, 3, 2});
  auto big_input = at::randn({8, 7, 6, 8});

  const std::vector<IValue> default_args{base_input};
  const std::vector<IValue> dtype_args{base_input, torch::kFloat};
  const std::vector<IValue> dim_args{base_input, c10::List<int64_t>{0, 2}};
  const std::vector<IValue> dim_keepdim_args{
      base_input, c10::List<int64_t>{1, 2}};
  const std::vector<IValue> dim_dtype_args{
      base_input, c10::List<int64_t>{0, 1}, torch::kBFloat16};

  // Single-input runs.
  testStaticRuntime(src_default, default_args);
  testStaticRuntime(src_dtype, dtype_args);
  testStaticRuntime(src_dim, dim_args);
  testStaticRuntime(src_dim_keepdim, dim_keepdim_args);
  testStaticRuntime(src_dim_dtype, dim_dtype_args);

  // Second runs with a 4-D input and different reduction dims.
  const std::vector<IValue> big_dim_args{big_input, c10::List<int64_t>{0, 3}};
  const std::vector<IValue> big_dim_keepdim_args{
      big_input, c10::List<int64_t>{1, 2}};
  const std::vector<IValue> big_dim_dtype_args{
      big_input, c10::List<int64_t>{1, 3}, torch::kBFloat16};

  testStaticRuntime(src_dim, dim_args, big_dim_args);
  testStaticRuntime(src_dim_keepdim, dim_keepdim_args, big_dim_keepdim_args);
  testStaticRuntime(src_dim_dtype, dim_dtype_args, big_dim_dtype_args);
}
179

180
TEST(StaticRuntime, Sigmoid) {
  // torch.sigmoid, checked with allclose, first with the fast-math flag at its
  // current value and then with fast math disabled.
  const auto sigmoid_script = R"JIT(
    def forward(self, inp: Tensor):
        b = torch.sigmoid(inp).clone()
        return (b)
  )JIT";
  auto a = at::randn({2, 3});
  auto b = at::randn({4, 3, 2});

  std::vector<IValue> args{a}, args2{b};

  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, args2, /*use_allclose=*/true);

  // This flag is process-global. The guard restores its *previous* value when
  // the test scope exits, including when an exception escapes the runs below.
  // It does not force the flag back to `true`, so later tests see the value
  // they started with.
  struct FastMathFlagGuard {
    bool saved;
    ~FastMathFlagGuard() {
      FLAGS_static_runtime_enable_fast_math = saved;
    }
  } fast_math_guard{FLAGS_static_runtime_enable_fast_math};

  FLAGS_static_runtime_enable_fast_math = false;
  testStaticRuntime(sigmoid_script, args, /*args2=*/{}, /*use_allclose=*/true);
  testStaticRuntime(sigmoid_script, args, args2, /*use_allclose=*/true);
}
199

200
TEST(StaticRuntime, Clone) {
  /*
  Each script follows the clone under test with another op (a second clone or
  a multiply). That makes the tested clone's output an intermediate value
  handled by the memory planner. The final op's output is returned to the
  client and is therefore not planner-managed.
  */
  const auto clone_script_0 = R"JIT(
    def forward(self, input):
        a = torch.clone(input).clone()
        return (a * a)
  )JIT";

  // clone() with an explicit memory_format argument.
  const auto clone_script_1 = R"JIT(
    def forward(self, input: Tensor, memory_format: int):
        a = torch.clone(input, memory_format=memory_format).clone()
        return (a * a)
  )JIT";

  /*
  The input is given stride 0 by expand_as. On this path native clone is
  called instead of the out variant.
  */
  const auto clone_script_2 = R"JIT(
    def forward(self, input: Tensor, other:Tensor):
        a = input.expand_as(other)
        return a.clone().clone()
  )JIT";

  /*
  The input is a strided slice, so clone sees non-contiguous tensor storage.
  */
  const auto clone_script_3 = R"JIT(
    def forward(self, input: Tensor):
        a = input[:, 0:10:2]
        return a.clone().clone()
  )JIT";

  auto plain_2x3 = at::randn({2, 3});
  auto strided_3x2 = at::randn({3, 2}).as_strided({3, 2}, {1, 3});
  auto strided_30x20 = at::randn({30, 20}).as_strided({30, 20}, {1, 3});
  auto four_d = at::randn({1, 20, 13, 8});
  auto empty_four_d = at::randn({1, 0, 3, 4});
  auto column_2x1 = at::randn({2, 1});
  auto wide_2x10 = at::randn({2, 10});
  auto wide_3x20 = at::randn({3, 20});

  const std::vector<IValue> contiguous_args{
      strided_3x2, c10::MemoryFormat::Contiguous};
  const std::vector<IValue> preserve_args{
      strided_30x20, c10::MemoryFormat::Preserve};
  const std::vector<IValue> channels_last_args{
      four_d, c10::MemoryFormat::ChannelsLast};
  const std::vector<IValue> empty_channels_last_args{
      empty_four_d, c10::MemoryFormat::ChannelsLast};
  const std::vector<IValue> expand_base_args{column_2x1, plain_2x3};
  const std::vector<IValue> expand_wide_args{column_2x1, wide_2x10};

  // Plain clone.
  testStaticRuntime(clone_script_0, {plain_2x3});
  testStaticRuntime(clone_script_0, {plain_2x3}, {strided_30x20});

  // Memory formats, individually and as (first, second) run pairs.
  testStaticRuntime(clone_script_1, contiguous_args);
  testStaticRuntime(clone_script_1, preserve_args);
  testStaticRuntime(clone_script_1, channels_last_args);
  testStaticRuntime(clone_script_1, empty_channels_last_args);
  testStaticRuntime(clone_script_1, contiguous_args, preserve_args);
  testStaticRuntime(clone_script_1, empty_channels_last_args, channels_last_args);

  // Stride-0 (expanded) input.
  testStaticRuntime(clone_script_2, expand_base_args);
  testStaticRuntime(clone_script_2, expand_base_args, expand_wide_args);

  // Sliced, non-contiguous input.
  testStaticRuntime(clone_script_3, {wide_2x10});
  testStaticRuntime(clone_script_3, {wide_2x10}, {wide_3x20});
}
270

271
TEST(StaticRuntime, Clamp) {
  // torch.clamp with scalar bounds (script 1) and with tensor bounds
  // (script 2); both are run once, then again with a larger input.
  const auto clamp_script_1 = R"JIT(
    def forward(self, inp: Tensor, min: int, max: int):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";

  const auto clamp_script_2 = R"JIT(
    def forward(self, inp: Tensor, min: Tensor, max: Tensor):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  auto a = at::randn({2, 3});
  auto max_t = at::full_like(a, 1);
  auto min_t = at::full_like(a, -1);

  auto b = at::randn({4, 3, 2});
  auto max_t1 = at::full_like(b, 1);
  auto min_t1 = at::full_like(b, -1);

  testStaticRuntime(clamp_script_1, {a, -1, 1});
  testStaticRuntime(clamp_script_2, {a, min_t, max_t});

  testStaticRuntime(clamp_script_1, {a, -1, 1}, {b, -1, 1});
  // The resized bounds are passed in (min, max) order, matching the first
  // run. Swapping them (min = 1, max = -1) makes min > max. torch.clamp then
  // maps every element to -1, so the second run would not exercise a real
  // clamp.
  testStaticRuntime(clamp_script_2, {a, min_t, max_t}, {b, min_t1, max_t1});
}
297

298
TEST(StaticRuntime, ClampMinOnly) {
  // torch.clamp with only a lower bound (max is None).
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float):
        a = torch.clamp(inp, min, None).clone()
        return (a)
  )JIT";
  auto first_input = at::randn({2, 3});
  auto resized_input = at::randn({4, 3, 2});
  const std::vector<IValue> first_args{first_input, 0.5};
  testStaticRuntime(src, first_args);
  testStaticRuntime(src, first_args, {resized_input, 0.25});
}
309

310
TEST(StaticRuntime, ClampMaxOnly) {
  // torch.clamp with only an upper bound (min is None).
  const auto src = R"JIT(
    def forward(self, inp: Tensor, max: float):
        a = torch.clamp(inp, None, max).clone()
        return (a)
  )JIT";
  auto first_input = at::randn({2, 3});
  auto resized_input = at::randn({4, 3, 2});
  const std::vector<IValue> first_args{first_input, 0.5};
  testStaticRuntime(src, first_args);
  testStaticRuntime(src, first_args, {resized_input, 0.25});
}
321

322
TEST(StaticRuntime, ClampIntTensor) {
  // torch.clamp with float bounds on integer-valued inputs.
  // NOTE(review): despite the test name, the inputs are created with
  // at::kFloat (integer values stored as float). Confirm whether an integral
  // dtype was intended.
  const auto src = R"JIT(
    def forward(self, inp: Tensor, min: float, max: float):
        a = torch.clamp(inp, min, max).clone()
        return (a)
  )JIT";
  // Values are drawn from [0, 20).
  auto a = at::randint(0, 20, {2, 3}, at::kFloat);
  auto b = at::randint(0, 20, {4, 3, 2}, at::kFloat);

  // Distinct bounds inside the value range, so that elements below `min`,
  // above `max`, and in between all occur. With min == max alone, every
  // output element is the constant 5 and the test cannot tell a correct
  // clamp from one that ignores its input.
  auto min = 5.0f;
  auto max = 15.0f;
  testStaticRuntime(src, {a, min, max});
  testStaticRuntime(src, {a, min, max}, {b, min, max});

  // Degenerate bounds (min == max) are still covered as an edge case.
  auto pinned = 5.0f;
  testStaticRuntime(src, {a, pinned, pinned});
  testStaticRuntime(src, {a, pinned, pinned}, {b, pinned, pinned});
}
335

336
TEST(StaticRuntime, LenWithTuple) {
  // aten::len on a graph input typed int[] (a list, despite the test name).
  const auto src = R"IR(
    graph(%input : int[]):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  std::vector<IValue> inputs;
  inputs.emplace_back(c10::List<int64_t>(4));
  testStaticRuntime(src, inputs);
}
345

346
TEST(StaticRuntime, LenWithTensor) {
  // aten::len on a Tensor graph input.
  const auto src = R"IR(
    graph(%input : Tensor):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  std::vector<IValue> inputs;
  inputs.emplace_back(at::randn({2, 2, 2}));
  testStaticRuntime(src, inputs);
}
355

356
TEST(StaticRuntime, LenWithStr) {
  // aten::len on a str graph input.
  const auto src = R"IR(
    graph(%input : str):
        %res : int = aten::len(%input)
        return (%res)
  )IR";

  std::vector<IValue> inputs;
  inputs.emplace_back("static_runtime");
  testStaticRuntime(src, inputs);
}
365

366
TEST(StaticRuntime, LenWithDict_str) {
  // len() of a two-entry Dict[str, str].
  const auto script = R"JIT(
    def forward(self, input: Dict[str, str]):
        return len(input)
  )JIT";

  c10::Dict<std::string, std::string> str_dict;
  str_dict.insert("abc", "123");
  str_dict.insert("def", "456");
  const std::vector<IValue> inputs{str_dict};
  testStaticRuntime(script, inputs);
}
377

378
TEST(StaticRuntime, LenWithDict_int) {
  // len() of a two-entry Dict[int, int].
  const auto script = R"JIT(
    def forward(self, input: Dict[int, int]):
        return len(input)
  )JIT";

  c10::Dict<int64_t, int64_t> int_dict;
  int_dict.insert(0, 1);
  int_dict.insert(2, 3);
  const std::vector<IValue> inputs{int_dict};
  testStaticRuntime(script, inputs);
}
389

390
TEST(StaticRuntime, LenWithDict_bool) {
  // len() of a two-entry Dict[bool, bool].
  const auto script = R"JIT(
    def forward(self, input: Dict[bool, bool]):
        return len(input)
  )JIT";

  c10::Dict<bool, bool> bool_dict;
  bool_dict.insert(true, false);
  bool_dict.insert(false, true);
  const std::vector<IValue> inputs{bool_dict};
  testStaticRuntime(script, inputs);
}
401

402
TEST(StaticRuntime, LenWithDict_float) {
  // len() of a two-entry Dict[float, float].
  const auto script = R"JIT(
    def forward(self, input: Dict[float, float]):
        return len(input)
  )JIT";

  c10::Dict<double, double> float_dict;
  float_dict.insert(0.1, 0.9);
  float_dict.insert(0.8, 0.18);
  const std::vector<IValue> inputs{float_dict};
  testStaticRuntime(script, inputs);
}
413

414
TEST(StaticRuntime, LenWithDict_complex) {
  // len() of a two-entry Dict[complex, complex].
  const auto script = R"JIT(
    def forward(self, input: Dict[complex, complex]):
        return len(input)
  )JIT";

  c10::Dict<c10::complex<double>, c10::complex<double>> complex_dict;
  complex_dict.insert(0.1, 0.4);
  complex_dict.insert(0.9, 0.45);
  const std::vector<IValue> inputs{complex_dict};
  testStaticRuntime(script, inputs);
}
425

426
TEST(StaticRuntime, LenWithDict_Tensor) {
  // len() of a two-entry Dict[Tensor, Tensor] with random 1x2 keys/values.
  const auto script = R"JIT(
    def forward(self, input: Dict[Tensor, Tensor]):
        return len(input)
  )JIT";

  c10::Dict<at::Tensor, at::Tensor> tensor_dict;
  tensor_dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  tensor_dict.insert(at::randn({1, 2}), at::randn({1, 2}));
  const std::vector<IValue> inputs{tensor_dict};
  testStaticRuntime(script, inputs);
}
437

438
TEST(StaticRuntime, Logit) {
  // torch.logit without eps, with a constant eps, and with eps as an input.
  // NOTE(review): all inputs are ones, where logit without eps is +inf.
  // Confirm this is the intended coverage.

  // no nnc
  const auto logit_script_1 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp).clone()
        return (a)
  )JIT";

  // with nnc
  const auto logit_script_2 = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.logit(inp, 1e-6).clone()
        return (a)
  )JIT";

  // no nnc
  const auto logit_script_3 = R"JIT(
    def forward(self, inp: Tensor, eps: float):
        a = torch.logit(inp, eps).clone()
        return (a)
  )JIT";

  auto ones_2x3 = at::ones({2, 3});
  double eps = 1e-6;
  const std::vector<IValue> tensor_only{ones_2x3};
  const std::vector<IValue> tensor_and_eps{ones_2x3, eps};

  auto ones_4x3x2 = at::ones({4, 3, 2});

  testStaticRuntime(logit_script_1, tensor_only);
  testStaticRuntime(logit_script_2, tensor_only);
  testStaticRuntime(logit_script_3, tensor_and_eps);

  testStaticRuntime(logit_script_1, tensor_only, {ones_4x3x2});
  testStaticRuntime(logit_script_2, tensor_only, {ones_4x3x2});
  testStaticRuntime(logit_script_3, tensor_and_eps, {ones_4x3x2, eps});
}
475

476
TEST(StaticRuntime, EmbeddingBag) {
  // Positional embedding_bag arguments after (weight, indices, offsets) are:
  // scale_grad_by_freq, mode, sparse, per_sample_weights, include_last_offset.
  // The scripts below vary mode (0 / 1 / 2) and include_last_offset.
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_mean = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_max = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_sum_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 0, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_mean_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 1, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  const std::string embedding_bag_max_last_offset = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c, False, 2, False, None, True)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";

  // Every script gets the same inputs, in this order.
  const std::string* const all_scripts[] = {
      &embedding_bag_default,
      &embedding_bag_mean,
      &embedding_bag_max,
      &embedding_bag_sum_last_offset,
      &embedding_bag_mean_last_offset,
      &embedding_bag_max_last_offset};

  at::Tensor table = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor indices = torch::tensor({0, 1, 0, 2});
  at::Tensor offsets = torch::tensor({0, 2, 4});
  std::vector<IValue> args{table, indices, offsets};
  for (const std::string* script : all_scripts) {
    testStaticRuntime(*script, args);
  }

  // Second run: a larger table and more bags.
  at::Tensor table_big = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor indices_big = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offsets_big = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{table_big, indices_big, offsets_big};
  for (const std::string* script : all_scripts) {
    testStaticRuntime(*script, args, args2);
  }
}
535

536
TEST(StaticRuntime, EmbeddingBagWithManagedOutput) {
  // embedding_bag whose outputs are consumed inside the graph (x + x) rather
  // than returned, so they are intermediate values.
  const std::string embedding_bag_managed_output = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        # The outputs of embedding_bag become an intermediate tensors
        # since they are not directly returned from the graph.
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return x + x
  )JIT";

  at::Tensor table = torch::randn({3, 8}, at::ScalarType::Float);
  at::Tensor indices = torch::tensor({0, 1, 0, 2});
  at::Tensor offsets = torch::tensor({0, 2});
  const std::vector<IValue> first_args{table, indices, offsets};

  at::Tensor table_big = torch::randn({6, 8}, at::ScalarType::Float);
  at::Tensor indices_big = torch::tensor({0, 1, 0, 2, 3, 4});
  at::Tensor offsets_big = torch::tensor({0, 2, 4, 5});
  const std::vector<IValue> second_args{table_big, indices_big, offsets_big};

  testStaticRuntime(embedding_bag_managed_output, first_args);
  testStaticRuntime(embedding_bag_managed_output, first_args, second_args);
}
558

559
TEST(StaticRuntime, EmbeddingBagWithExtraneousOutput) {
  // Parses `ir`, runs RemoveUnnecessaryOutputs, and checks that the result
  // does (rewritten == true) or does not (rewritten == false) contain
  // static_runtime::embedding_bag.
  const auto expect_embedding_bag_rewrite = [](const std::string& ir,
                                               bool rewritten) {
    auto graph = getGraphFromIR(ir);
    RemoveUnnecessaryOutputs(graph);
    if (rewritten) {
      torch::jit::testing::FileCheck()
          .check("static_runtime::embedding_bag")
          ->run(*graph);
    } else {
      torch::jit::testing::FileCheck()
          .check_not("static_runtime::embedding_bag")
          ->run(*graph);
    }
  };

  // Only %y0 is used (mode 0).
  const std::string embedding_bag_default_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";

  // Only %y0 is used (mode 1).
  const std::string embedding_bag_mean_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=1]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";

  // Only %y0 is used (mode 2, include_last_offset).
  const std::string embedding_bag_max_last_offset_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=2]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=1]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%y0, %none)
        return (%res)
  )IR";

  // All four outputs are used.
  const std::string embedding_bag_normal_ir = R"IR(
    graph(%weight, %indices, %offsets):
        %scale_grad_by_freq : bool = prim::Constant[value=0]()
        %mode : int = prim::Constant[value=0]()
        %sparse : bool = prim::Constant[value=0]()
        %per_sample_weights : NoneType = prim::Constant()
        %include_last_offset : bool = prim::Constant[value=0]()
        %y0 : Tensor, %y1 : Tensor, %y2 : Tensor, %y3 : Tensor = aten::embedding_bag(%weight, %indices, %offsets, %scale_grad_by_freq, %mode, %sparse, %per_sample_weights, %include_last_offset)
        %none : NoneType = prim::Constant()
        %res0 : Tensor = aten::clone(%y0, %none)
        %res1 : Tensor = aten::clone(%y1, %none)
        %res2 : Tensor = aten::clone(%y2, %none)
        %res3 : Tensor = aten::clone(%y3, %none)
        return (%res0, %res1, %res2, %res3)
  )IR";

  expect_embedding_bag_rewrite(embedding_bag_default_ir, true);
  expect_embedding_bag_rewrite(embedding_bag_mean_ir, true);
  expect_embedding_bag_rewrite(embedding_bag_max_last_offset_ir, true);
  expect_embedding_bag_rewrite(embedding_bag_normal_ir, false);

  at::Tensor table = torch::randn({3, 11}, at::ScalarType::Float);
  at::Tensor indices = torch::tensor({0, 1, 0, 2});
  at::Tensor offsets = torch::tensor({0, 2, 4});
  std::vector<IValue> args{table, indices, offsets};
  testStaticRuntime(embedding_bag_default_ir, args);
  testStaticRuntime(embedding_bag_mean_ir, args);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args);

  at::Tensor table_big = torch::randn({10, 11}, at::ScalarType::Float);
  at::Tensor indices_big = torch::tensor({0, 1, 0, 2, 1});
  at::Tensor offsets_big = torch::tensor({0, 1, 2, 3, 4, 5});
  std::vector<IValue> args2{table_big, indices_big, offsets_big};
  testStaticRuntime(embedding_bag_default_ir, args, args2);
  testStaticRuntime(embedding_bag_mean_ir, args, args2);
  testStaticRuntime(embedding_bag_max_last_offset_ir, args, args2);
}
651

652
TEST(StaticRuntime, EmbeddingBagWithMixedInt32Int64Input) {
  // embedding_bag with int64 indices and int32 offsets.
  const std::string embedding_bag_default = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        x, y, z, _ = torch.embedding_bag(a, b, c)
        return (x.clone(), y.clone(), z.clone(), _.clone())
  )JIT";
  auto table = torch::randn({3, 11}, at::ScalarType::Float);
  auto long_indices = torch::tensor({0, 1, 0, 2}, at::ScalarType::Long);
  auto int_offsets = torch::tensor({0, 2, 4}, at::ScalarType::Int);
  const std::vector<IValue> inputs{table, long_indices, int_offsets};
  testStaticRuntime(embedding_bag_default, inputs);
}
664

665
TEST(StaticRuntime, LayerNorm) {
  // layer_norm with and without affine weight/bias, plus a variant whose
  // input is transposed (non-contiguous) inside the script.
  const std::string layer_norm_with_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
        return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
  )JIT";

  const std::string layer_norm_without_weights = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int]):
        return torch.layer_norm(input, normalized_shape, None, None, 1e-05, False).clone()
  )JIT";

  const std::string layer_norm_with_noncontiguous_input = R"JIT(
    def forward(self, input: Tensor, normalized_shape: List[int], weight: Tensor, bias: Tensor):
        input = torch.transpose(input, 1, 2)
        return torch.layer_norm(input, normalized_shape, weight, bias, 1e-05, False).clone()
  )JIT";

  const auto batch_of_1 = torch::rand({1, 2, 2, 2});
  const auto batch_of_3 = torch::rand({3, 2, 2, 2});

  // normalized_shape with 2 and then 3 dims, each of size 2.
  for (int num_norm_dims : {2, 3}) {
    std::vector<int64_t> normalized_shape(num_norm_dims, 2);
    const auto weight = torch::rand(normalized_shape);
    const auto bias = torch::rand(normalized_shape);

    const std::vector<IValue> affine_args{
        batch_of_1, normalized_shape, weight, bias};
    const std::vector<IValue> affine_args_big{
        batch_of_3, normalized_shape, weight, bias};
    testStaticRuntime(layer_norm_with_weights, affine_args);
    testStaticRuntime(layer_norm_with_weights, affine_args, affine_args_big);
    testStaticRuntime(layer_norm_with_noncontiguous_input, affine_args);

    const std::vector<IValue> shape_only_args{batch_of_1, normalized_shape};
    testStaticRuntime(layer_norm_without_weights, shape_only_args);
    testStaticRuntime(
        layer_norm_without_weights,
        shape_only_args,
        {batch_of_3, normalized_shape});
  }
}
700

701
TEST(StaticRuntime, Bmm) {
  // Batched matmul with two different batch sizes and matrix shapes.
  const auto bmm_script = R"JIT(
    def forward(self, inp: Tensor, mat2: Tensor):
      return torch.bmm(inp, mat2).clone()
  )JIT";

  // Batch of 10: (4x5) @ (5x6).
  auto lhs_b10 = at::randn({10, 4, 5});
  auto rhs_b10 = at::randn({10, 5, 6});

  // Batch of 12: (5x6) @ (6x7).
  auto lhs_b12 = at::randn({12, 5, 6});
  auto rhs_b12 = at::randn({12, 6, 7});

  const std::vector<IValue> b10_args{lhs_b10, rhs_b10};
  const std::vector<IValue> b12_args{lhs_b12, rhs_b12};
  testStaticRuntime(bmm_script, b10_args);
  testStaticRuntime(bmm_script, b12_args);
  testStaticRuntime(bmm_script, b10_args, b12_args);
}
719

720
TEST(StaticRuntime, Addmm) {
  // addmm with two argument sets that differ in shapes and (beta, alpha).
  const auto addmm_script = R"JIT(
    def forward(self, inp: Tensor, mat1: Tensor, mat2: Tensor, beta: float, alpha: float):
      return torch.addmm(inp, mat1, mat2, alpha=alpha, beta=beta).clone()
  )JIT";
  // First set: 1-D `inp` of length 5 with a (3x4) @ (4x5) product.
  auto inp_1d = at::randn({5});
  auto lhs_3x4 = at::randn({3, 4});
  auto rhs_4x5 = at::randn({4, 5});

  // Second set: 2-D `inp` of shape 3x7 with a (3x6) @ (6x7) product.
  auto inp_3x7 = at::randn({3, 7});
  auto lhs_3x6 = at::randn({3, 6});
  auto rhs_6x7 = at::randn({6, 7});

  // The trailing scalars are (beta, alpha), matching the script signature.
  const std::vector<IValue> set_a{inp_1d, lhs_3x4, rhs_4x5, 1.0, 2.0};
  const std::vector<IValue> set_b{inp_3x7, lhs_3x6, rhs_6x7, 2.0, 1.0};
  testStaticRuntime(addmm_script, set_a);
  testStaticRuntime(addmm_script, set_b);
  testStaticRuntime(addmm_script, set_a, set_b);
}
739

740
TEST(StaticRuntime, Abs) {
  // Tensor.abs, run once and then with a larger 3-D input.
  const auto abs_script = R"JIT(
    def forward(self, a):
      return a.abs().clone()
  )JIT";
  auto input_2x3 = at::randn({2, 3});
  auto input_4x2x3 = at::randn({4, 2, 3});
  const std::vector<IValue> first_args{input_2x3};
  const std::vector<IValue> second_args{input_4x2x3};
  testStaticRuntime(abs_script, first_args);
  testStaticRuntime(abs_script, first_args, second_args);
}
752

753
TEST(StaticRuntime, Binary) {
  // Addition on tensors, ints, and int lists, plus list/tuple construction
  // and unpacking on a pair of tensors.
  const auto add_script = R"JIT(
    def forward(self, a, b):
        c = a + b
        return (c.clone())
  )JIT";

  const auto add_script_ints = R"JIT(
    def forward(self, a: int, b: int):
        c = a + b
        d = c + 1
        return d
  )JIT";

  const auto add_list_script = R"JIT(
    def forward(self, a: List[int], b: List[int]):
        c = a + b
        return c[::]
  )JIT";

  const auto list_construct_script = R"JIT(
    def forward(self, a, b):
      return [a, b]
  )JIT";

  const auto list_construct_script_2 = R"JIT(
    def forward(self, a, b):
      c = a + a
      return [c, c]
  )JIT";

  const auto list_construct_script_3 = R"JIT(
    def forward(self, a, b):
      c = a + a
      return [c, c.flatten()]
  )JIT";

  const auto list_unpack_script = R"JIT(
    def forward(self, a, b):
      c = [a, b]
      x, y = c
      z = x + y
      return z.clone()
  )JIT";

  const auto list_unpack_script_2 = R"JIT(
    def forward(self, a, b):
      c = [a, b]
      x, y = c
      z = (x, y)
      return z
  )JIT";

  const auto tuple_construct_script = R"JIT(
    def forward(self, a, b):
      return (a, b)
  )JIT";

  const auto tuple_construct_script_2 = R"JIT(
    def forward(self, a, b):
      return (a.flatten(), b)
  )JIT";

  auto lhs = at::randn({2, 3});
  auto rhs = at::ones({2, 3});

  auto lhs_big = at::randn({4, 2, 3});
  auto rhs_big = at::ones({4, 2, 3});

  const std::vector<IValue> tensor_args{lhs, rhs};

  testStaticRuntime(add_script, tensor_args);
  testStaticRuntime(add_script_ints, {1, 2});
  testStaticRuntime(add_script, tensor_args, {lhs_big, rhs_big});

  // Container construct/unpack scripts all take the same pair of tensors.
  const char* const container_scripts[] = {
      list_construct_script,
      list_construct_script_2,
      list_construct_script_3,
      list_unpack_script,
      list_unpack_script_2,
      tuple_construct_script,
      tuple_construct_script_2};
  for (const char* script : container_scripts) {
    testStaticRuntime(script, tensor_args);
  }

  const std::vector<IValue> list_args{
      c10::List<int64_t>{1, 2, 3}, c10::List<int64_t>{4, 5, 6}};
  testStaticRuntime(add_list_script, list_args);
}
839

840
TEST(StaticRuntime, MatMul) {
  // torch.matmul across 1-D / 2-D / batched (> 2-D) operand combinations.
  const auto aten_matmul = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.matmul(a, b).clone()
  )JIT";

  // vector . vector
  const std::vector<IValue> vec_vec{at::randn({3}), at::randn({3})};
  testStaticRuntime(aten_matmul, vec_vec);
  // matrix @ matrix
  const std::vector<IValue> mat_mat{at::randn({3, 2}), at::randn({2, 3})};
  testStaticRuntime(aten_matmul, mat_mat);
  // vector @ matrix
  const std::vector<IValue> vec_mat{at::randn({3}), at::randn({3, 5})};
  testStaticRuntime(aten_matmul, vec_mat);
  // matrix @ vector
  const std::vector<IValue> mat_vec{at::randn({3, 5}), at::randn({5})};
  testStaticRuntime(aten_matmul, mat_vec);
  // batched: {3, 1, 4, 5} @ {2, 5, 6}
  const std::vector<IValue> batched{
      at::randn({3, 1, 4, 5}), at::randn({2, 5, 6})};
  testStaticRuntime(aten_matmul, batched);

  // Second run switches from the matrix @ vector case to the batched case.
  testStaticRuntime(aten_matmul, mat_vec, batched);
}
864

865
TEST(StaticRuntime, Sign) {
  const auto sign_tensor = R"JIT(
    def forward(self, input: Tensor):
        return torch.sign(input).clone()
  )JIT";

  // A 2-D input first, then a larger 3-D input as the second run.
  const auto small_input = at::randn({2, 3});
  const auto large_input = at::randn({4, 3, 2});

  const std::vector<IValue> first_args{small_input};
  testStaticRuntime(sign_tensor, first_args);
  testStaticRuntime(sign_tensor, first_args, {large_input});
}
878

879
TEST(StaticRuntime, Div) {
  const auto div_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.div(a, b).clone()
  )JIT";

  const auto div_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.div(a, b).clone()
  )JIT";

  const auto div_tensor_mode = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";

  const auto div_scalar_mode = R"JIT(
    def forward(self, a: Tensor, b: float, c: str):
        return torch.div(a, b, rounding_mode=c).clone()
  )JIT";

  const auto div_strided = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        a_strided = torch.transpose(a, 0, 1)
        b_strided = torch.transpose(b, 0, 1)
        return torch.div(a_strided, b_strided).clone()
  )JIT";

  // Runs `script` on `first` alone, then on `first` followed by `second`.
  const auto run_fixed_then_dynamic = [](const std::string& script,
                                         const std::vector<IValue>& first,
                                         const std::vector<IValue>& second) {
    testStaticRuntime(script, first);
    testStaticRuntime(script, first, second);
  };

  auto lhs = at::randn({2, 3});
  auto rhs = at::randn({2, 3});
  // Non-contiguous divisor: a transposed view.
  auto rhs_strided = at::randn({3, 2}).transpose(0, 1);
  auto lhs_big = at::randn({4, 3, 2});
  auto rhs_big = at::randn({4, 3, 2});
  auto rhs_big_strided = at::randn({3, 4, 2}).transpose(0, 1);

  run_fixed_then_dynamic(div_tensor, {lhs, rhs}, {lhs_big, rhs_big});
  run_fixed_then_dynamic(div_strided, {lhs, rhs}, {lhs_big, rhs_big});
  run_fixed_then_dynamic(
      div_tensor, {lhs, rhs_strided}, {lhs_big, rhs_big_strided});
  run_fixed_then_dynamic(div_scalar, {lhs, 3}, {lhs_big, 4});
  run_fixed_then_dynamic(
      div_tensor_mode, {lhs, rhs, "floor"}, {lhs_big, rhs_big, "floor"});
  run_fixed_then_dynamic(
      div_scalar_mode, {lhs, 2.3, "trunc"}, {lhs_big, 1.5, "trunc"});
}
936

937
TEST(StaticRuntime, Mul) {
  const auto mul_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.mul(a, b).clone()
  )JIT";

  const auto mul_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.mul(a, b).clone()
  )JIT";

  const auto mul_list = R"JIT(
    def forward(self, a: List[int], n: int):
        b = a * n
        return b[::]
  )JIT";

  auto lhs_small = at::randn({3, 3});
  auto rhs_small = at::randn({3, 3});
  auto lhs_large = at::randn({3, 3, 3});
  auto rhs_large = at::randn({3, 3, 3});

  // Tensor * Tensor, fixed shape then growing to 3-D.
  const std::vector<IValue> tensor_small{lhs_small, rhs_small};
  const std::vector<IValue> tensor_large{lhs_large, rhs_large};
  testStaticRuntime(mul_tensor, tensor_small);
  testStaticRuntime(mul_tensor, tensor_small, tensor_large);

  // Tensor * int, same shape progression.
  const std::vector<IValue> scalar_small{lhs_small, 42};
  const std::vector<IValue> scalar_large{lhs_large, 42};
  testStaticRuntime(mul_scalar, scalar_small);
  testStaticRuntime(mul_scalar, scalar_small, scalar_large);

  // List[int] * int (list repetition).
  testStaticRuntime(mul_list, {c10::List<int64_t>{1, 2}, 3});
}
974

975
TEST(StaticRuntime, Log) {
  const auto log_tensor = R"JIT(
    def forward(self, inp: Tensor):
        a = torch.log(inp).clone()
        return (a)
  )JIT";

  // Take |randn| so every element is non-negative and log stays real-valued.
  const auto first_input = at::abs(at::randn({2, 3}));
  const auto second_input = at::abs(at::randn({4, 3, 2}));

  testStaticRuntime(log_tensor, {first_input});
  testStaticRuntime(log_tensor, {first_input}, {second_input});
}
990

991
TEST(StaticRuntime, Sub) {
  const auto sub_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.sub(a, b).clone()
  )JIT";

  const auto sub_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.sub(a, b).clone()
  )JIT";

  const auto sub_tensor_alpha = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: float):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";

  const auto sub_scalar_alpha = R"JIT(
    def forward(self, a: Tensor, b: float, c: int):
        return torch.sub(a, b, alpha=c).clone()
  )JIT";

  const auto sub_two_scalars = R"JIT(
    def forward(self, a: int, b: int):
        return (a - b - b)
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});

  // Each tensor overload runs once on the 2x3 inputs, then again with the
  // larger 4x3x2 inputs as the second set. Passing the large inputs as the
  // second set makes the out variant resize its output.
  std::vector<IValue> args0{a, b};
  testStaticRuntime(sub_tensor, args0);
  testStaticRuntime(sub_tensor, args0, {c, d});

  std::vector<IValue> args1{a, 3};
  testStaticRuntime(sub_scalar, args1);
  testStaticRuntime(sub_scalar, args1, {c, 4});

  // The alpha overloads get the same two-set treatment as the overloads
  // above. Running {c, d, ...} standalone would never exercise resizing.
  std::vector<IValue> args2{a, b, 2.3};
  testStaticRuntime(sub_tensor_alpha, args2);
  testStaticRuntime(sub_tensor_alpha, args2, {c, d, 3.1});

  std::vector<IValue> args3{a, 2.3, 4};
  testStaticRuntime(sub_scalar_alpha, args3);
  testStaticRuntime(sub_scalar_alpha, args3, {c, 1.3, 2});

  // int - int - int: no tensors involved, so a single run suffices.
  std::vector<IValue> args4{1, 2};
  testStaticRuntime(sub_two_scalars, args4);
}
1041

1042
TEST(StaticRuntime, NanToNum) {
  const auto nan_to_num_script = R"JIT(
    def forward(self, a: Tensor, nan: float, posinf: float, neginf: float):
        return torch.nan_to_num(a, nan, posinf, neginf).clone()
  )JIT";

  const auto inf = std::numeric_limits<double>::infinity();
  const auto nan = std::numeric_limits<double>::quiet_NaN();

  // Small fixed input holding NaN, -inf and +inf.
  auto fixed_input = torch::tensor({{1.0, nan}, {-inf, inf}});

  // Larger random float input with non-finite values planted at fixed
  // flat indices (all < 3 * 6).
  auto random_input = at::randn({3, 6});
  float* elems = random_input.data_ptr<float>();
  elems[0] = static_cast<float>(nan);
  elems[4] = static_cast<float>(-inf);
  elems[11] = static_cast<float>(inf);
  elems[13] = static_cast<float>(nan);

  // Replacement values passed as (nan, posinf, neginf).
  const std::vector<IValue> first_args{fixed_input, 1.0, 2.0, -2.0};
  const std::vector<IValue> second_args{random_input, 1.0, 2.0, -2.0};

  // Outputs are compared with allclose and NaN-equality enabled.
  testStaticRuntime(
      nan_to_num_script,
      first_args,
      /*args2*/ {},
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
  testStaticRuntime(
      nan_to_num_script,
      first_args,
      second_args,
      /*use_allclose*/ true,
      /*use_equalnan*/ true);
}
1075

1076
TEST(StaticRuntime, Stack) {
  const auto stack_dim = R"JIT(
    def forward(self, a: Tensor, b: Tensor, dim: int):
        inputs = [a]
        inputs.append(b) # mutation to avoid using VarStack
        return torch.stack(inputs, dim = dim).clone()
  )JIT";

  const auto stack_three = R"JIT(
    def forward(self, a: Tensor, b: Tensor, c: Tensor):
        inputs = [a, b]
        inputs.append(c) # mutation to avoid using VarStack
        return torch.stack(inputs).clone()
  )JIT";

  // Three 2x2 tensors for the first run, three 3x3x3 tensors for the second.
  auto s0 = at::randn({2, 2});
  auto s1 = at::randn({2, 2});
  auto s2 = at::randn({2, 2});

  auto l0 = at::randn({3, 3, 3});
  auto l1 = at::randn({3, 3, 3});
  auto l2 = at::randn({3, 3, 3});

  // Two-tensor stack with an explicit dim (including a negative one).
  const std::vector<IValue> pair_small{s0, s1, 0};
  const std::vector<IValue> pair_large{l0, l1, 1};
  const std::vector<IValue> pair_negative_dim{l0, l1, -1};
  testStaticRuntime(stack_dim, pair_small);
  testStaticRuntime(stack_dim, pair_small, pair_large);
  testStaticRuntime(stack_dim, pair_negative_dim);

  // Three-tensor stack along the default dim.
  const std::vector<IValue> triple_small{s0, s1, s2};
  const std::vector<IValue> triple_large{l0, l1, l2};
  testStaticRuntime(stack_three, triple_small);
  testStaticRuntime(stack_three, triple_small, triple_large);
}
1114

1115
TEST(StaticRuntime, ReLU) {
  const auto relu_script = R"JIT(
    def forward(self, a: Tensor):
        return torch.relu(a).clone()
  )JIT";

  // Values drawn from [-10, 10), so inputs contain both signs.
  const std::vector<IValue> first_args{at::randint(-10, 10, {2, 4})};
  const std::vector<IValue> second_args{at::randint(-10, 10, {3, 6})};

  testStaticRuntime(relu_script, first_args);
  testStaticRuntime(relu_script, first_args, second_args);
}
1129

1130
TEST(StaticRuntime, Tanh) {
  const auto tanh_script = R"JIT(
    def forward(self, a):
        return torch.tanh(a).clone()
  )JIT";

  const std::vector<IValue> square_args{at::randn({2, 2})};
  const std::vector<IValue> cube_args{at::randn({3, 3, 3})};

  // Results are compared with allclose rather than exact equality.
  testStaticRuntime(
      tanh_script, square_args, /*args2*/ {}, /*use_allclose*/ true);
  testStaticRuntime(tanh_script, square_args, cube_args, /*use_allclose*/ true);
}
1144

1145
TEST(StaticRuntime, Norm) {
  const auto norm_2arg = R"JIT(
    def forward(self, a: Tensor, p: int):
        return torch.norm(a, p).clone()
  )JIT";

  const auto norm_3arg = R"JIT(
    def forward(self, a: Tensor, p: int, dtype: int):
        return torch.norm(a, p, dtype=dtype).clone()
  )JIT";

  const auto norm_4arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool):
        return torch.norm(a, p, dim, keepdim).clone()
  )JIT";

  const auto norm_5arg = R"JIT(
    def forward(self, a: Tensor, p: int, dim: List[int], keepdim: bool, dtype: int):
        return torch.norm(a, p, dim, keepdim, dtype=dtype).clone()
  )JIT";

  auto small = at::randn({2, 3});
  auto large = at::randn({4, 3, 5});
  auto reduce_dims = std::vector<int64_t>({1});
  auto float_dtype = at::ScalarType::Float;

  // Dim-less reductions: the second run disables check_resize.
  const std::vector<IValue> two_args{small, 2};
  testStaticRuntime(norm_2arg, two_args);
  testStaticRuntime(norm_2arg, two_args, {large, 2}, false, false, false);

  const std::vector<IValue> three_args{small, 2, float_dtype};
  testStaticRuntime(norm_3arg, three_args);
  testStaticRuntime(
      norm_3arg, three_args, {large, 2, float_dtype}, false, false, false);

  // Reductions over dim 1, with and without keepdim.
  const std::vector<IValue> four_args{small, 3, reduce_dims, false};
  testStaticRuntime(norm_4arg, four_args);
  testStaticRuntime(norm_4arg, four_args, {large, 3, reduce_dims, false});

  const std::vector<IValue> five_args{small, 4, reduce_dims, true, float_dtype};
  testStaticRuntime(norm_5arg, five_args);
  testStaticRuntime(
      norm_5arg, five_args, {large, 4, reduce_dims, true, float_dtype});
}
1187

1188
TEST(StaticRuntime, Reshape) {
  const auto reshape_script_1 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.reshape(shape)
        return b + b
  )JIT";

  const auto reshape_script_2 = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        return b.reshape(shape)
  )JIT";

  const auto reshape_script_3 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape)
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return b.reshape(shape), g
  )JIT";

  // exercise reshape_copy and flatten_copy
  const auto reshape_script_4 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        k = inp + inp
        a = k + k
        b = a.reshape(shape)
        c = a.flatten().reshape(shape)
        return b + c
  )JIT";

  // exercise reshape_copy
  const auto reshape_script_5 = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = a.reshape(shape).relu()
        d = c + c
        e = d + d
        f = e * e
        g = f * f
        return g
  )JIT";

  const auto reshape_inplace_script = R"JIT(
    def forward(self, inp: Tensor, shape: List[int]):
        a = inp + inp
        b = a.reshape(shape)
        c = b.sigmoid_()
        d = c + c
        e = a + a
        f = b + b
        return (d, e, f)
  )JIT";

  // b is non-contiguous (a transposed view of the input), so reshape must
  // copy rather than return a view.
  const auto reshape_incontiguous_script = R"JIT(
    def forward(self, a: Tensor, shape: List[int]):
        b = a.transpose(0, 1)
        c = b.reshape(shape)
        c = c.relu()
        return (c)
  )JIT";

  // First set: 2x3 input reshaped to 3x2.
  auto a = at::randn({2, 3});
  auto b = std::vector<int64_t>({3, 2});
  std::vector<IValue> args{a, b};

  // Second set: 4x5 input reshaped to 5x1x2x2, which is larger than the first.
  auto c = at::randn({4, 5});
  auto d = std::vector<int64_t>({5, 1, 2, 2});
  std::vector<IValue> args1{c, d};

  const std::initializer_list<const char*> scripts = {
      reshape_script_1,
      reshape_script_2,
      reshape_script_3,
      reshape_script_4,
      reshape_script_5,
      reshape_inplace_script,
      reshape_incontiguous_script};

  // Every script with fixed shapes first, then every script with the second,
  // differently-shaped input set.
  for (const char* script : scripts) {
    testStaticRuntime(script, args);
  }
  for (const char* script : scripts) {
    testStaticRuntime(script, args, args1);
  }
}
1280

1281
TEST(StaticRuntime, Repeat) {
  const std::string repeat = R"JIT(
    def forward(self, a: Tensor, repeats: List[int]):
        return torch.repeat(a, repeats).clone()
  )JIT";

  auto small = at::randn({2, 3});
  auto large = at::randn({4, 3});
  auto small_repeats = std::vector<int64_t>({1, 2});
  auto large_repeats = std::vector<int64_t>({2, 3});

  const std::vector<IValue> small_args{small, small_repeats};
  const std::vector<IValue> large_args{large, large_repeats};

  // Each set on its own, then small followed by large.
  testStaticRuntime(repeat, small_args);
  testStaticRuntime(repeat, large_args);
  testStaticRuntime(repeat, small_args, large_args);
}
1298

1299
TEST(StaticRuntime, Flatten) {
  // exercise flatten_copy
  const auto flatten_script_1 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a * a
        c = torch.flatten(b, start_dim, end_dim)
        d = torch.relu(c)
        return d
  )JIT";

  const auto flatten_script_2 = R"JIT(
    def forward(self, a: Tensor, start_dim: int, end_dim: int):
        b = a.transpose(0, 1)
        return torch.flatten(b, start_dim, end_dim).clone()
  )JIT";

  // Runs the flatten scripts on an input of `shape`. The second input set
  // uses the same shape with its leading dimension multiplied by 6.
  auto test_flatten = [&](const std::vector<int64_t>& shape,
                          int64_t start_dim,
                          int64_t end_dim) {
    std::vector<int64_t> shape1(shape);
    if (!shape1.empty()) {
      shape1[0] *= 6;
    }
    auto a = at::randn(shape);
    auto b = at::randn(shape1);
    std::vector<IValue> args{a, start_dim, end_dim};
    // A 0-d input has no leading dimension to scale, so skip the resize check.
    const bool check_resize = !shape1.empty();
    testStaticRuntime(flatten_script_1, args);
    testStaticRuntime(
        flatten_script_1,
        args,
        {b, start_dim, end_dim},
        /*use_allclose=*/false,
        /*use_equalnan=*/false,
        check_resize);
    // The transposed (non-contiguous) variant only runs for rank > 2 inputs.
    if (shape.size() > 2) {
      testStaticRuntime(flatten_script_2, args);
      testStaticRuntime(flatten_script_2, args, {b, start_dim, end_dim});
    }
  };

  test_flatten({2, 3}, 0, 1);
  test_flatten({2, 1, 3}, 1, 2);
  test_flatten({0, 1, 3, 0}, 1, 2);
  test_flatten({2, 3}, 1, 1);
  test_flatten({}, 0, 0);
}
1345

1346
TEST(StaticRuntime, pow) {
  const auto pow_script_ten_sca = R"JIT(
    def forward(self, input : Tensor, exponent : int):
        return torch.pow(input, exponent).clone()
  )JIT";

  const auto pow_script_ten_ten = R"JIT(
    def forward(self, input : Tensor, exponent : Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";

  const auto pow_script_sca_ten = R"JIT(
    def forward(self, input : int, exponent : Tensor):
        return torch.pow(input, exponent).clone()
  )JIT";

  auto base_small = at::randn({2, 3});
  auto exp_small = at::randn({2, 3});
  auto base_large = at::randn({4, 3, 2});
  auto exp_large = at::randn({4, 3, 2});

  // Tensor ** int
  const std::vector<IValue> ten_sca_args{base_small, 4};
  testStaticRuntime(pow_script_ten_sca, ten_sca_args);
  testStaticRuntime(pow_script_ten_sca, ten_sca_args, {base_large, 4});

  // Tensor ** Tensor. The base is made non-negative because the exponents
  // are arbitrary floats.
  const std::vector<IValue> ten_ten_args{at::abs(base_small), exp_small};
  testStaticRuntime(pow_script_ten_ten, ten_ten_args);
  testStaticRuntime(
      pow_script_ten_ten, ten_ten_args, {at::abs(base_large), exp_large});

  // int ** Tensor
  const std::vector<IValue> sca_ten_args{5, exp_small};
  testStaticRuntime(pow_script_sca_ten, sca_ten_args);
  testStaticRuntime(pow_script_sca_ten, sca_ten_args, {3, exp_large});
}
1379

1380
TEST(StaticRuntime, to) {
  const auto to_script_dtype = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  const auto to_script_dtype_strided = R"JIT(
    def forward(self, input: Tensor, dtype: int, non_blocking: bool, copy: bool, memory_format: int):
        b = input.permute(0, 2, 3, 1)
        return torch.to(b, dtype, non_blocking, copy, memory_format).clone()
  )JIT";

  const auto to_script_prim_dtype = R"JIT(
    def forward(self, input:Tensor, dtype: Optional[int], non_blocking: bool, copy: bool):
        a = input + input
        return torch.to(a, dtype, non_blocking, copy).clone()
  )JIT";

  const auto to_script_other = R"JIT(
    def forward(self, input:Tensor, other: Tensor, non_blocking: bool, copy: bool, memory_format: int):
        a = input + input
        return torch.to(a, other, non_blocking, copy, memory_format).clone()
  )JIT";

  // if input is float tensor, b could be alias of a
  const auto to_script_alias = R"JIT(
    def forward(self, input:Tensor):
        a = input + input
        b = a.float()
        c = b * b
        return (c)
  )JIT";

  const auto to_script_fails_managed_output_check = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return e
  )JIT";

  const auto to_script_select_tensor_output_into_tuple = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float()
        return (d, e)
  )JIT";

  const auto to_script_memory_planning_fail = R"JIT(
    def forward(self, a, b):
        d = a.half() * b.half()
        e = d.float().relu()
        return e
  )JIT";

  // Runs every script above for one (dtype, non_blocking, copy,
  // memory_format) combination. The primary input is always a float tensor
  // from at::randn, so `dtype` is the conversion target.
  auto test_to = [&](at::ScalarType dtype,
                     bool non_blocking,
                     bool copy,
                     c10::MemoryFormat memory_format) {
    auto a = at::randn({4, 3, 1, 2});
    auto other = at::randn({4, 3, 1, 2}).to(dtype);
    auto a2 = at::randn({3, 2, 2, 4});
    auto a2_other = at::randn({3, 2, 2, 4}).to(dtype);

    std::vector<IValue> args0{a, dtype, non_blocking, copy, memory_format};
    std::vector<IValue> args1{a, dtype, non_blocking, copy};
    std::vector<IValue> args2{a, other, non_blocking, copy, memory_format};
    std::vector<IValue> args2WithDifferentOtherType{
        a,
        at::randn({4, 3, 1, 2}, ScalarType::Double),
        non_blocking,
        copy,
        memory_format};
    std::vector<IValue> args3{a, std::nullopt, non_blocking, copy};

    std::vector<IValue> args0WithInt{
        a, ScalarType::Int, non_blocking, copy, memory_format};
    testStaticRuntime(
        to_script_dtype,
        args0,
        args0WithInt,
        /* default for use_allclose */ false,
        /* default for use_equalnan */ false,
        /* check_resize */ false);
    testStaticRuntime(to_script_dtype_strided, args0);
    testStaticRuntime(to_script_prim_dtype, args1);
    // The dtype=None variant is only exercised when copy is false.
    if (!copy) {
      testStaticRuntime(to_script_prim_dtype, args3);
    }
    // Second set of args tests case where the `other` tensor's dtype
    // changes between iterations.
    testStaticRuntime(
        to_script_other,
        args2,
        args2WithDifferentOtherType,
        /* default for use_allclose */ false,
        /* default for use_equalnan */ false,
        /* check_resize */ false);
    testStaticRuntime(to_script_alias, {a});

    testStaticRuntime(to_script_memory_planning_fail, {a, a});
    testStaticRuntime(to_script_fails_managed_output_check, {a, a});
    testStaticRuntime(to_script_select_tensor_output_into_tuple, {a, a});

    // dynamic shapes
    testStaticRuntime(
        to_script_dtype, args0, {a2, dtype, non_blocking, copy, memory_format});
    testStaticRuntime(
        to_script_dtype_strided,
        args0,
        {a2, dtype, non_blocking, copy, memory_format});
    testStaticRuntime(
        to_script_prim_dtype, args1, {a2, dtype, non_blocking, copy});
    if (!copy) {
      testStaticRuntime(
          to_script_prim_dtype, args3, {a2, std::nullopt, non_blocking, copy});
    }
    testStaticRuntime(
        to_script_other,
        args2,
        {a2, a2_other, non_blocking, copy, memory_format});
    testStaticRuntime(to_script_alias, {a}, {a2});
  };
  for (const bool non_blocking : {false, true}) {
    for (const bool copy : {false, true}) {
      // float->float, NCHW->NHWC
      test_to(
          at::ScalarType::Float,
          non_blocking,
          copy,
          c10::MemoryFormat::ChannelsLast);
      // float->half
      test_to(
          at::ScalarType::Half,
          non_blocking,
          copy,
          c10::MemoryFormat::Preserve);
      // float->float
      test_to(
          at::ScalarType::Float,
          non_blocking,
          copy,
          c10::MemoryFormat::Contiguous);
      // float->bool
      test_to(
          at::ScalarType::Bool,
          non_blocking,
          copy,
          c10::MemoryFormat::Contiguous);
      // TODO: check if fbgemm is enabled properly in this case
      // float->half, NCHW->NHWC (the input is float; Half is the target)
      test_to(
          at::ScalarType::Half,
          non_blocking,
          copy,
          c10::MemoryFormat::ChannelsLast);
    }
  }
}
1521

1522
TEST(StaticRuntime, ExpandAs) {
  const auto expand_as_script = R"JIT(
    def forward(self, input: Tensor, other:Tensor):
        a = input.expand_as(other)
        return a.clone()
  )JIT";

  // Expand an Nx1 column to Nx2; N grows from 3 to 4 between the two sets.
  auto column_small = at::randn({3, 1});
  auto target_small = at::randn({3, 2});
  auto column_large = at::randn({4, 1});
  auto target_large = at::randn({4, 2});

  const std::vector<IValue> first_args{column_small, target_small};
  const std::vector<IValue> second_args{column_large, target_large};
  testStaticRuntime(expand_as_script, first_args);
  testStaticRuntime(expand_as_script, first_args, second_args);
}
1538

1539
TEST(StaticRuntime, Full) {
  const auto full_script = R"JIT(
    def forward(self,
                size: List[int],
                fill_value: int,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.full(size,
                      fill_value,
                      dtype=dtype,
                      layout=layout,
                      device=device,
                      pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  // All argument sets use a strided CPU layout without pinned memory.
  const auto make_args = [&cpu](const c10::List<int64_t>& size,
                                int fill_value,
                                at::ScalarType dtype) {
    return std::vector<IValue>{
        size, fill_value, dtype, at::kStrided, cpu, false};
  };

  c10::List<int64_t> size0{2, 5};
  const auto int_args = make_args(size0, 4, at::ScalarType::Int);
  const auto float_args = make_args(size0, 4, at::ScalarType::Float);
  c10::List<int64_t> size1{5, 6};
  const auto larger_args = make_args(size1, 5, at::ScalarType::Float);

  testStaticRuntime(full_script, int_args);
  // Only the dtype changes here, so no resize is expected.
  testStaticRuntime(
      full_script,
      int_args,
      float_args,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(full_script, int_args, larger_args);
}
1576

1577
TEST(StaticRuntime, FullLike) {
  const auto full_like_script = R"JIT(
    def forward(self,
                a: Tensor,
                fill_value: int,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool],
                memory_format: Optional[int]):
        b = torch.full_like(a,
                            fill_value,
                            dtype=dtype,
                            layout=layout,
                            device=device,
                            pin_memory=pin_memory,
                            memory_format=memory_format)
        return (b.clone())
  )JIT";

  auto small = at::randn({2, 3});
  auto large = at::randn({3, 4, 2});
  auto cpu = at::Device(DeviceType::CPU);

  // Fill value 4, strided CPU layout, no pinning, contiguous memory format;
  // only the template tensor and dtype vary.
  const auto make_args = [&cpu](const at::Tensor& like, at::ScalarType dtype) {
    return std::vector<IValue>{
        like,
        4,
        dtype,
        at::kStrided,
        cpu,
        false,
        c10::MemoryFormat::Contiguous};
  };

  const auto int_args = make_args(small, at::ScalarType::Int);
  const auto float_args = make_args(small, at::ScalarType::Float);
  const auto larger_args = make_args(large, at::ScalarType::Float);

  testStaticRuntime(full_like_script, int_args);
  // Only the dtype changes here, so no resize is expected.
  testStaticRuntime(
      full_like_script,
      int_args,
      float_args,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(full_like_script, int_args, larger_args);
}
1634

1635
TEST(StaticRuntime, Ones) {
  const auto script = R"JIT(
    def forward(self,
                size: List[int],
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.ones(size,
                       dtype=dtype,
                       layout=layout,
                       device=device,
                       pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto int_dtype = at::ScalarType::Int;
  auto cpu = at::Device(DeviceType::CPU);

  // Int, strided, CPU, unpinned; the second set only enlarges the size.
  c10::List<int64_t> small_size{2, 5};
  const std::vector<IValue> small_args{
      small_size, int_dtype, at::kStrided, cpu, false};
  c10::List<int64_t> large_size{5, 6};
  const std::vector<IValue> large_args{
      large_size, int_dtype, at::kStrided, cpu, false};

  testStaticRuntime(script, small_args);
  testStaticRuntime(script, small_args, large_args);
}
1660

1661
TEST(StaticRuntime, OnesLike) {
  const auto script = R"JIT(
    def forward(self,
                input: Tensor,
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool],
                memory_format: Optional[int]):
        a = torch.ones_like(input,
                            dtype=dtype,
                            layout=layout,
                            device=device,
                            pin_memory=pin_memory,
                            memory_format=memory_format)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  // Strided CPU layout, unpinned, contiguous; only input and dtype vary.
  const auto make_args = [&cpu](const at::Tensor& input, at::ScalarType dtype) {
    return std::vector<IValue>{
        input, dtype, at::kStrided, cpu, false, c10::MemoryFormat::Contiguous};
  };

  auto small_input = at::randn({2, 5});
  const auto int_args = make_args(small_input, at::ScalarType::Int);
  const auto float_args = make_args(small_input, at::ScalarType::Float);
  auto large_input = at::randn({5, 6});
  const auto larger_args = make_args(large_input, at::ScalarType::Float);

  testStaticRuntime(script, int_args);
  // Only the dtype changes here, so no resize is expected.
  testStaticRuntime(
      script,
      int_args,
      float_args,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(script, int_args, larger_args);
}
1713

1714
TEST(StaticRuntime, Zeros) {
  const auto script = R"JIT(
    def forward(self,
                size: List[int],
                dtype: Optional[int],
                layout: Optional[int],
                device: Optional[Device],
                pin_memory: Optional[bool]):
        a = torch.zeros(size,
                       dtype=dtype,
                       layout=layout,
                       device=device,
                       pin_memory=pin_memory)
        return (a.clone())
  )JIT";

  auto cpu = at::Device(DeviceType::CPU);
  // Strided CPU layout, unpinned; only size and dtype vary.
  const auto make_args = [&cpu](const c10::List<int64_t>& size,
                                at::ScalarType dtype) {
    return std::vector<IValue>{size, dtype, at::kStrided, cpu, false};
  };

  c10::List<int64_t> size0{2, 5};
  const auto int_args = make_args(size0, at::ScalarType::Int);
  const auto float_args = make_args(size0, at::ScalarType::Float);
  c10::List<int64_t> size1{5, 6};
  const auto larger_args = make_args(size1, at::ScalarType::Float);

  testStaticRuntime(script, int_args);
  // Only the dtype changes here, so no resize is expected.
  testStaticRuntime(
      script,
      int_args,
      float_args,
      /*use_allclose=*/false,
      /*use_equalnan=*/false,
      /*check_resize=*/false);
  testStaticRuntime(script, int_args, larger_args);
}
1749

1750
TEST(StaticRuntime, Linear) {
  const auto linear_script = R"JIT(
    def forward(self, inp: Tensor, weights: Tensor, bias: Optional[Tensor]) -> Tensor:
        return torch.linear(inp, weights, bias).clone()
  )JIT";

  // Small case: input 1x2, weights 1x2, bias 1x1.
  auto input = at::randn({1, 2});
  auto weights = at::randn({1, 2});
  auto bias = at::randn({1, 1});
  // Larger case used as the second run: input 6x3, weights 6x3, bias 6x6.
  auto input_big = at::randn({6, 3});
  auto weights_big = at::randn({6, 3});
  auto bias_big = at::randn({6, 6});

  const std::vector<IValue> with_bias{input, weights, bias};
  const std::vector<IValue> without_bias{input, weights, std::nullopt};
  const std::vector<IValue> big_with_bias{input_big, weights_big, bias_big};
  const std::vector<IValue> big_without_bias{
      input_big, weights_big, std::nullopt};

  testStaticRuntime(linear_script, with_bias);
  testStaticRuntime(linear_script, without_bias);

  testStaticRuntime(linear_script, with_bias, big_with_bias);
  testStaticRuntime(linear_script, with_bias, big_without_bias);
}
1776

1777
TEST(StaticRuntime, VarCat) {
  const auto var_cat_script = R"JIT(
    def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
      return torch.cat([inp1, inp2], dim).clone()
  )JIT";

  // Concatenate two 2-D tensors along dim 0.
  const std::vector<IValue> cat_rows{at::randn({4, 6}), at::randn({5, 6}), 0};
  testStaticRuntime(var_cat_script, cat_rows);

  // Concatenate two 3-D tensors along the middle dim.
  const std::vector<IValue> cat_middle{
      at::randn({4, 5, 6}), at::randn({4, 8, 6}), 1};
  testStaticRuntime(var_cat_script, cat_middle);

  // Concatenate two 3-D tensors along the last dim, given positively...
  const std::vector<IValue> cat_last{
      at::randn({4, 5, 6}), at::randn({4, 5, 7}), 2};
  testStaticRuntime(var_cat_script, cat_last);

  // ...and as a negative index.
  const std::vector<IValue> cat_negative{
      at::randn({4, 5, 6}), at::randn({4, 5, 7}), -1};
  testStaticRuntime(var_cat_script, cat_negative);

  // Changing rank and concat dim between runs.
  testStaticRuntime(var_cat_script, cat_rows, cat_middle);
}
1801

1802
TEST(StaticRuntime, LeakyReLU) {
  torch::jit::Module mod = getLeakyReLUConstScriptModel();
  const auto inputs = torch::randn({2, 2});
  const std::vector<c10::IValue> forward_args{inputs};

  // Reference result from the JIT graph executor.
  const at::Tensor expected = mod.forward(forward_args).toTensor();

  // Same inputs through static runtime, which must also not leak memory.
  torch::jit::StaticModule smod(mod);
  const at::Tensor actual = smod(forward_args, {}).toTensor();
  smod.runtime().check_for_memory_leak();
  EXPECT_TRUE(torch::allclose(expected, actual, 1e-6));
}
1817

1818
// Builds a ProcessedNodeInputs of the same size as `inputs` and copies each
// value into it, in order.
static ProcessedNodeInputs createProcessedNodeInputs(
    c10::ArrayRef<uint16_t> inputs) {
  ProcessedNodeInputs result(inputs.size());
  for (const auto idx : c10::irange(inputs.size())) {
    result[idx] = inputs[idx];
  }
  return result;
}

static void checkProcessedNodeInputs(
1828
    const ProcessedNodeInputs& io,
1829
    c10::ArrayRef<uint16_t> inputs) {
1830
  ASSERT_EQ(inputs.size(), io.size());
1831
  for (const auto idx : c10::irange(inputs.size())) {
1832
    EXPECT_EQ(inputs[idx], io[idx]);
1833
  }
1834
}
1835

1836
// Builds a ProcessedNodeInputs from `inputs` and checks that the values
// survive construction, copy construction and move construction.
static void testProcessedNodeInputsRoundTrip(c10::ArrayRef<uint16_t> inputs) {
  auto io = createProcessedNodeInputs(inputs);
  checkProcessedNodeInputs(io, inputs);

  ProcessedNodeInputs copied(io);
  checkProcessedNodeInputs(copied, inputs);
  ProcessedNodeInputs moved(std::move(io));
  checkProcessedNodeInputs(moved, inputs);
}

// Round-trips ProcessedNodeInputs across the inline / outline size boundaries
// labelled in `testCases`. It also checks copy and move assignment between
// every pair of cases.
TEST(ProcessedNodeInputs, Basic) {
  std::vector<std::vector<uint16_t>> testCases = {
      {}, // empty
      {0xABCD, 0x5a5a}, // inline
      {0x11, 0x22, 0x33, 0x44, 0x55}, // max inline size
      {0x11, 0x22, 0x33, 0x44, 0x55, 0x66}, // minimum outline size
      std::vector<uint16_t>(100, 0x5a), // large outline size
  };

  for (const auto& values : testCases) {
    testProcessedNodeInputsRoundTrip(values);
    for (const auto& values2 : testCases) {
      auto from = createProcessedNodeInputs(values);
      auto to = createProcessedNodeInputs(values2);

      to = from;
      checkProcessedNodeInputs(to, values);
      // Copy assignment must leave the source untouched.
      checkProcessedNodeInputs(from, values);

      auto toMoveInto = createProcessedNodeInputs(values2);
      toMoveInto = std::move(from);
      checkProcessedNodeInputs(toMoveInto, values);
    }
  }
}

// prim isinstance checks against int, torch.Tensor and a (bool, int) tuple,
// each run with tensor arguments of two different shapes.
TEST(StaticRuntime, isinstance) {
  const auto isinstance_int_script = R"JIT(
    def forward(self, a: Any):
        return isinstance(a, int)
  )JIT";

  const auto isinstance_tensor_script = R"JIT(
    def forward(self, a: Any):
        return isinstance(a, torch.Tensor)
  )JIT";

  const auto isinstance_many_types_script = R"JIT(
    def forward(self, a: Any):
        return isinstance(a, (bool, int))
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({2, 2, 2});

  std::vector<at::IValue> args{a};
  std::vector<at::IValue> args2{b};

  testStaticRuntime(isinstance_int_script, args);
  testStaticRuntime(isinstance_int_script, args, args2);

  testStaticRuntime(isinstance_tensor_script, args);
  testStaticRuntime(isinstance_tensor_script, args, args2);

  testStaticRuntime(isinstance_many_types_script, args);
  testStaticRuntime(isinstance_many_types_script, args, args2);
}

// prim::TypeCheck against declared Float(2, 2, cpu) / Float(3, 3) types.
// The first run uses matching inputs; the second uses a 2x2x2 tensor in the
// second slot, which does not match.
TEST(StaticRuntime, TypeCheck) {
  const auto typecheck_ir = R"IR(
  graph(%a.1 : Tensor,
        %b.1 : Tensor):
    %t0 : Float(2, 2, strides=[2, 1], device=cpu), %t1 : Float(3, 3, strides=[3, 1]), %type_matched : bool = prim::TypeCheck[types=[Float(2, 2, strides=[2, 1], device=cpu), Float(3, 3, strides=[3, 1])]](%a.1, %b.1)
    return (%t0, %t1, %type_matched)
  )IR";

  // Tensor::to returns a tensor rather than converting in place, so the
  // result has to be kept for the device to take effect.
  auto a = at::zeros({2, 2}, at::kFloat).to(at::kCPU);
  auto b = at::ones({3, 3}, at::kFloat);
  auto c = at::ones({2, 2, 2}, at::kFloat);

  std::vector<IValue> args_correct = {a, b};
  std::vector<IValue> args_incorrect = {a, c};

  testStaticRuntime(typecheck_ir, args_correct);
  testStaticRuntime(typecheck_ir, args_correct, args_incorrect);
}

// aten::index via TorchScript subscripting. Covers a boolean mask, an integer
// index tensor, an index list containing None, and two index tensors.
TEST(StaticRuntime, Index) {
  const auto index_without_none_script = R"JIT(
    def forward(self, a: Tensor, idx: Tensor):
        return a[idx].clone()
  )JIT";

  // Index with boolean mask
  auto a = at::arange(4, at::kFloat).view({2, 2});
  auto idx_a = torch::tensor({{0, 1}, {0, 0}}, at::kBool);
  std::vector<IValue> args_a{a, idx_a};

  // Index with tensor
  auto b = at::arange(27, at::kFloat).view({3, 3, 3});
  auto idx_b = torch::tensor({0, 1, 2}, at::kLong);
  std::vector<IValue> args_b{b, idx_b};

  testStaticRuntime(index_without_none_script, args_a);
  testStaticRuntime(index_without_none_script, args_a, args_b);

  const auto index_with_none_script = R"JIT(
    def forward(self, a: Tensor, idx: Tensor, none: Optional[Tensor]):
        return a[idx, none].clone()
  )JIT";

  // Index with None
  // When indexing with none, the shape of `f` becomes [2, 1, 2],
  // so the mask must be reshaped appropriately.
  auto f = at::arange(4, at::kFloat).view({2, 1, 2});
  auto idx_f_reshape = torch::tensor({{{0, 1}}, {{0, 0}}}, at::kBool);
  std::vector<IValue> args_f_with_none{f, idx_f_reshape};
  // Third argument: a None IValue for the Optional[Tensor] parameter.
  args_f_with_none.emplace_back();

  testStaticRuntime(index_with_none_script, args_f_with_none);
  testStaticRuntime(
      index_with_none_script,
      args_f_with_none,
      {IValue(b), IValue(idx_b), IValue()});

  const auto index_with_two_tensors_script = R"JIT(
    def forward(self, a: Tensor, idx_a: Tensor, idx_b: Tensor):
        return a[idx_a, idx_b].clone()
  )JIT";

  // Index with multiple tensors
  const auto& c = a; // 2x2 tensor
  auto idx_c1 = torch::tensor({0, 0}, at::kLong);
  auto idx_c2 = torch::tensor({0}, at::kLong);
  std::vector<IValue> args_c{c, idx_c1, idx_c2};

  const auto& d = b; // 3x3x3 tensor
  auto idx_d1 = torch::tensor({{0, 0, 2}, {0, 1, 1}}, at::kLong);
  auto idx_d2 = torch::tensor({{1, 1, 0}, {1, 0, 2}}, at::kLong);
  std::vector<IValue> args_d{d, idx_d1, idx_d2};

  testStaticRuntime(index_with_two_tensors_script, args_c, args_d);
}

// aten::index_select along dim 0 with int32 indices drawn within the input's
// bounds. The second run uses a larger input and index tensor.
TEST(StaticRuntime, IndexSelect) {
  const std::string script = R"IR(
    graph(%self: Tensor, %dim: int, %index: Tensor):
        %bias: None = prim::Constant()
        %ret = aten::index_select(%self, %dim, %index)
        %cloned = aten::clone(%ret, %bias)
        return (%cloned)
  )IR";

  auto self0 = at::rand({6});
  auto dim0 = 0;
  auto index0 = at::randint(0, 5, {6}, torch::kInt32);
  std::vector<IValue> args{self0, dim0, index0};
  testStaticRuntime(script, args);

  auto self1 = at::rand({128});
  auto dim1 = 0;
  auto index1 = at::randint(0, 127, {127}, torch::kInt32);
  std::vector<IValue> args2{self1, dim1, index1};
  testStaticRuntime(script, args, args2);
}

// torch.clamp_min with an int scalar and with a float scalar bound, each run
// on a 2x2 input and then on a 3x3x3 input.
TEST(StaticRuntime, ClampMin) {
  const auto clamp_min_int_script = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.clamp_min(a, b).clone()
  )JIT";

  const auto clamp_min_float_script = R"JIT(
    def forward(self, a: Tensor, b: float):
        return torch.clamp_min(a, b).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({3, 3, 3});
  int scalar_int = 1;
  // TorchScript `float` is a C++ double; declaring it as double avoids
  // silently narrowing the 3.14 literal.
  double scalar_float = 3.14;

  std::vector<IValue> args_a_int{a, scalar_int};
  std::vector<IValue> args_b_int{b, scalar_int};

  testStaticRuntime(clamp_min_int_script, args_a_int);
  testStaticRuntime(clamp_min_int_script, args_a_int, args_b_int);

  std::vector<IValue> args_a_float{a, scalar_float};
  std::vector<IValue> args_b_float{b, scalar_float};

  testStaticRuntime(clamp_min_float_script, args_a_float);
  testStaticRuntime(clamp_min_float_script, args_a_float, args_b_float);
}

// torch.argmin without a dim, with a dim, and with keepdim=True.
TEST(StaticRuntime, Argmin) {
  const auto argmin_script = R"JIT(
    def forward(self, a: Tensor):
        return torch.argmin(a).clone()
  )JIT";

  const auto argmin_with_dim_script = R"JIT(
    def forward(self, a: Tensor, dim: int):
        return torch.argmin(a, dim).clone()
  )JIT";

  const auto argmin_with_keep_dim_script = R"JIT(
    def forward(self, a: Tensor, dim: int):
        return torch.argmin(a, dim, True).clone()
  )JIT";

  auto a = at::randn({2, 2});
  auto b = at::randn({17, 2, 1});

  testStaticRuntime(argmin_script, {a});
  // Resize checking is disabled here, presumably because the full reduction
  // yields a single index for both inputs, so the output never grows.
  testStaticRuntime(
      argmin_script,
      {a},
      {b},
      /* use_allclose */ false,
      /* use_equalnan */ false,
      /* check_resize */ false);

  int dim_a = 0;
  int dim_b = 1;

  std::vector<IValue> args_a{a, dim_a};
  std::vector<IValue> args_b{b, dim_b};

  testStaticRuntime(argmin_with_dim_script, args_a);
  testStaticRuntime(argmin_with_dim_script, args_a, args_b);

  testStaticRuntime(argmin_with_keep_dim_script, args_a);
  testStaticRuntime(argmin_with_keep_dim_script, args_a, args_b);
}

// torch.softmax over every dim of 2D and 3D inputs, plus the explicit-dtype
// overload.
TEST(StaticRuntime, Softmax) {
  const auto softmax_script = R"JIT(
    def forward(self, a: Tensor, dim: int):
        return torch.softmax(a, dim).clone()
  )JIT";

  const auto softmax_script_with_dtype = R"JIT(
    def forward(self, a: Tensor, dim: int, dtype: int):
        return torch.softmax(a, dim, dtype=dtype).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto b = at::randn({3, 3, 3});

  testStaticRuntime(softmax_script, {a, 0});
  testStaticRuntime(softmax_script, {a, 1});

  testStaticRuntime(softmax_script, {b, 0});
  testStaticRuntime(softmax_script, {b, 1});
  testStaticRuntime(softmax_script, {b, 2});

  testStaticRuntime(softmax_script_with_dtype, {a, 1, at::ScalarType::Float});
  testStaticRuntime(softmax_script_with_dtype, {b, 1, at::ScalarType::Float});
}

// Dict __getitem__ with Tensor, int and str keys.
TEST(StaticRuntime, GetItem_Dict) {
  const auto getitem_dict_tensor_script = R"JIT(
    def forward(self, key: Tensor):
        d = {key: 1}
        return d[key]
  )JIT";

  const auto getitem_dict_int_script = R"JIT(
    def forward(self, key: int):
        d = {key: 1}
        return d[key]
  )JIT";

  const auto getitem_dict_str_script = R"JIT(
    def forward(self, key: str):
        d = {key: 1}
        return d[key]
  )JIT";

  int int_key = 0;
  std::string str_key = "str";

  // No need to test these multiple times, args are not tensors
  testStaticRuntime(getitem_dict_int_script, {int_key});
  testStaticRuntime(getitem_dict_str_script, {str_key});

  auto a = torch::tensor({1});
  auto b = torch::tensor({1, 1});

  testStaticRuntime(getitem_dict_tensor_script, {a});
  testStaticRuntime(getitem_dict_tensor_script, {a}, {b});
}

// List __getitem__ on an int list and a Tensor list, with positive and
// negative indices.
TEST(StaticRuntime, GetItem_List) {
  const auto getitem_list_int_script = R"JIT(
    def forward(self, idx: int):
        lst = [1, 2, 3]
        return lst[idx]
  )JIT";

  const auto getitem_list_tensor_script = R"JIT(
    def forward(self, tensor: Tensor, idx: int):
        lst = [tensor, tensor]
        return lst[idx]
  )JIT";

  testStaticRuntime(getitem_list_int_script, {1});
  testStaticRuntime(getitem_list_int_script, {-1});

  auto a = torch::tensor({1});
  auto b = torch::tensor({1, 1});

  testStaticRuntime(getitem_list_tensor_script, {a, 1});
  testStaticRuntime(getitem_list_tensor_script, {a, 1}, {b, -1});
}

// torch.transpose on a 2D input, then on a 3D input with different dims.
TEST(StaticRuntime, Transpose) {
  const auto transpose_script = R"JIT(
    def forward(self, a: Tensor, dim1: int, dim2: int):
        return torch.transpose(a, dim1, dim2).clone()
  )JIT";

  auto a = at::randn({2, 2});
  int dim1_a = 0;
  int dim2_a = 1;
  std::vector<IValue> args_a{a, dim1_a, dim2_a};

  auto b = at::randn({3, 3, 3});
  int dim1_b = 0;
  int dim2_b = 2;
  std::vector<IValue> args_b{b, dim1_b, dim2_b};

  testStaticRuntime(transpose_script, args_a);
  testStaticRuntime(transpose_script, args_a, args_b);
}

// torch.permute on its own, and then followed by reshape.
TEST(StaticRuntime, Permute) {
  auto permute_script = R"JIT(
    def forward(self, a: Tensor, dims: List[int]):
        return torch.permute(a, dims).clone()
  )JIT";

  auto a = at::randn({2, 2});
  c10::List<int64_t> dims_a{1, 0};
  std::vector<IValue> args_a{a, dims_a};

  auto b = at::randn({3, 3, 3});
  c10::List<int64_t> dims_b{0, 2, 1};
  std::vector<IValue> args_b{b, dims_b};

  testStaticRuntime(permute_script, args_a);
  testStaticRuntime(permute_script, args_a, args_b);

  permute_script = R"JIT(
    def forward(self, a: Tensor, dims: List[int], shape: List[int]):
        return torch.permute(a, dims).reshape(shape).clone()
  )JIT";

  a = at::randn({8, 16, 4});
  dims_a = {0, 2, 1};
  // Target shape for the reshape after permuting.
  c10::List<int64_t> shape{-1, 16};
  testStaticRuntime(permute_script, {a, dims_a, shape});
}

// Tensor.slice with explicit start/end/step, and with None start/end.
TEST(StaticRuntime, Slice) {
  const auto slice_script = R"JIT(
    def forward(self, a: Tensor, dim: int, start: int, end: int, step: int):
      return a.slice(dim, start, end, step).clone()
  )JIT";

  auto a = at::randn({2, 2});
  int dim_a = 1;
  int start_a = 0;
  int end_a = 1;
  int step_a = 1;
  std::vector<IValue> args_a{a, dim_a, start_a, end_a, step_a};

  auto b = at::randn({3, 3, 3});
  int dim_b = 2;
  int start_b = 0;
  int end_b = 1;
  int step_b = 2;
  std::vector<IValue> args_b{b, dim_b, start_b, end_b, step_b};

  testStaticRuntime(slice_script, args_a);
  testStaticRuntime(slice_script, args_a, args_b);

  const auto slice_script2 = R"JIT(
    def forward(self, a: Tensor, dim: int, step: int):
      return a.slice(dim, None, None, step).clone()
  )JIT";
  std::vector<IValue> args_c{b, dim_b, step_b};
  testStaticRuntime(slice_script2, args_c);
}

// Tensor.narrow with int start, on a 2D input and then on a 3D input.
TEST(StaticRuntime, Narrow) {
  const auto narrow_with_int_script = R"JIT(
    def forward(self, a: Tensor, dim: int, start: int, length: int):
        return a.narrow(dim, start, length).clone()
  )JIT";

  auto a = at::randn({5, 5});
  int dim_a = 0;
  int start_a_int = 3;
  int len_a = 2;
  std::vector<IValue> args_a{a, dim_a, start_a_int, len_a};

  auto b = at::randn({5, 5, 5});
  int dim_b = 1;
  int start_b_int = 2;
  int len_b = 3;
  std::vector<IValue> args_b{b, dim_b, start_b_int, len_b};

  testStaticRuntime(narrow_with_int_script, args_a);
  testStaticRuntime(narrow_with_int_script, args_a, args_b);
}

// Unpacking of 2- and 3-element Tensor tuples, re-run with larger tensors.
TEST(StaticRuntime, TupleUnpack) {
  const auto two_tuple_unpack_script = R"JIT(
    def forward(self, tup: Tuple[Tensor, Tensor]):
        a, b = tup
        return (a, b)
  )JIT";

  const auto three_tuple_unpack_script = R"JIT(
    def forward(self, tup: Tuple[Tensor, Tensor, Tensor]):
        a, b, c = tup
        return (a, b, c)
  )JIT";

  auto two_tup = c10::ivalue::Tuple::create({at::randn({1}), at::randn({1})});
  auto two_tup_large =
      c10::ivalue::Tuple::create({at::randn({2, 2}), at::randn({2, 2})});

  auto three_tup = c10::ivalue::Tuple::create(
      {at::randn({1}), at::randn({1}), at::randn({1})});
  auto three_tup_large = c10::ivalue::Tuple::create(
      {at::randn({2, 2}), at::randn({2, 2}), at::randn({2, 2})});

  testStaticRuntime(two_tuple_unpack_script, {two_tup});
  testStaticRuntime(two_tuple_unpack_script, {two_tup}, {two_tup_large});

  testStaticRuntime(three_tuple_unpack_script, {three_tup});
  testStaticRuntime(three_tuple_unpack_script, {three_tup}, {three_tup_large});
}

// list.append on a prefilled int list and on an empty Tensor list.
TEST(StaticRuntime, Append) {
  const auto append_int_script = R"JIT(
    def forward(self, a: int):
        lst = [1, 2, 3]
        lst.append(a)
        return lst
  )JIT";

  const auto append_tensor_script = R"JIT(
    def forward(self, a: Tensor):
        lst = []
        lst.append(a)
        return lst
  )JIT";

  std::vector<IValue> args_int{1};

  testStaticRuntime(append_int_script, args_int);

  std::vector<IValue> args_tensor{at::randn({1})};
  std::vector<IValue> args_tensor_large{at::randn({2, 2})};

  testStaticRuntime(append_tensor_script, args_tensor);
  testStaticRuntime(append_tensor_script, args_tensor, args_tensor_large);
}

// quantized::linear_prepack + quantized::linear + dequantize, on qint8 weights
// and quint8 inputs of two different shapes.
TEST(StaticRuntime, QuantizedLinear) {
  const std::string quantize_script = R"IR(
    graph(%input: Tensor, %weights: Tensor):
        %scale: float = prim::Constant[value=1.]()
        %zero_point: int = prim::Constant[value=1]()
        %bias: None = prim::Constant()
        %packed_params = quantized::linear_prepack(%weights, %bias)
        %1254 = quantized::linear(%input, %packed_params, %scale, %zero_point)
        %1249: Tensor = aten::dequantize(%1254)
        return (%1249)
  )IR";
  at::Tensor weight =
      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQInt8);
  at::Tensor input =
      at::quantize_per_tensor(torch::randn({3, 2}), 2, 3, torch::kQUInt8);

  at::Tensor weight_2 =
      at::quantize_per_tensor(torch::randn({8, 3}), 2, 3, torch::kQInt8);
  at::Tensor input_2 =
      at::quantize_per_tensor(torch::randn({9, 3}), 2, 3, torch::kQUInt8);

  testStaticRuntime(quantize_script, {input, weight}, {input_2, weight_2});
}

// quantized::linear_prepack_fp16 + quantized::linear_dynamic_fp16 on float
// inputs of two different shapes.
TEST(StaticRuntime, QuantizedLinearDynamicFp16) {
  const std::string quantized_linear_dynamic_fp16_script = R"IR(
    graph(%input: Tensor, %weights: Tensor):
        %bias: None = prim::Constant()
        %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
        %output = quantized::linear_dynamic_fp16(%input, %packed_params)
        %ret = aten::clone(%output, %bias)
        return (%ret)
  )IR";
  at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
  at::Tensor input = torch::randn({3, 2}, torch::kFloat);

  at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
  at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);

  testStaticRuntime(
      quantized_linear_dynamic_fp16_script,
      {input, weight},
      {input_2, weight_2});
}

// quantized::linear_prepack_fp16 + quantized::linear_relu_dynamic_fp16 on
// float inputs of two different shapes.
TEST(StaticRuntime, QuantizedLinearReluDynamicFp16) {
  const std::string quantized_linear_relu_dynamic_fp16_script = R"IR(
    graph(%input: Tensor, %weights: Tensor):
        %bias: None = prim::Constant()
        %packed_params = quantized::linear_prepack_fp16(%weights, %bias)
        %output = quantized::linear_relu_dynamic_fp16(%input, %packed_params)
        %ret = aten::clone(%output, %bias)
        return (%ret)
  )IR";
  at::Tensor weight = torch::randn({3, 2}, torch::kFloat);
  at::Tensor input = torch::randn({3, 2}, torch::kFloat);

  at::Tensor weight_2 = torch::randn({4, 3}, torch::kFloat);
  at::Tensor input_2 = torch::randn({5, 3}, torch::kFloat);

  testStaticRuntime(
      quantized_linear_relu_dynamic_fp16_script,
      {input, weight},
      {input_2, weight_2});
}

// torch.stack with a runtime `dim`. Covers 2D/3D inputs, a negative dim, the
// labelled non-serial and fast paths, and a re-run with different shapes.
TEST(StaticRuntime, VarStack) {
  const auto var_stack_script = R"JIT(
    def forward(self, inp1: Tensor, inp2: Tensor, dim: int):
        return torch.stack([inp1, inp2], dim).clone()
  )JIT";

  // 2D tensors - stack dim = 0
  std::vector<IValue> args1 = {at::randn({6, 6}), at::randn({6, 6}), 0};
  testStaticRuntime(var_stack_script, args1);

  // 3D tensors - stack dim = 1
  std::vector<IValue> args2 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 1};
  testStaticRuntime(var_stack_script, args2);

  // 3D tensors - stack dim = 2
  std::vector<IValue> args3 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), 2};
  testStaticRuntime(var_stack_script, args3);

  // Negative dim
  std::vector<IValue> args4 = {at::randn({4, 5, 6}), at::randn({4, 5, 6}), -1};
  testStaticRuntime(var_stack_script, args4);

  // Non-serial path
  std::vector<IValue> args5 = {at::randn({1, 2, 3}), at::randn({1, 2, 3}), 3};
  testStaticRuntime(var_stack_script, args5);

  // Fast path
  std::vector<IValue> args6 = {at::randn({1}), at::randn({1}), 0};
  testStaticRuntime(var_stack_script, args6);

  testStaticRuntime(var_stack_script, args1, args2);
}

// torch.fmod with a tensor divisor, re-run with larger 3D inputs.
TEST(StaticRuntime, FmodTensor) {
  const auto fmod_tensor = R"JIT(
    def forward(self, a: Tensor, b: Tensor):
        return torch.fmod(a, b).clone()
  )JIT";

  // fmod tensor version
  auto a = at::randn({2, 3});
  auto b = at::randn({2, 3});
  std::vector<IValue> args0{a, b};
  testStaticRuntime(fmod_tensor, args0);

  // check for dynamic shapes
  auto c = at::randn({4, 3, 2});
  auto d = at::randn({4, 3, 2});
  std::vector<IValue> args1{c, d};
  testStaticRuntime(fmod_tensor, args0, args1);
}

// torch.fmod with an int scalar divisor, on float inputs and then on int32
// inputs, including a re-run with larger inputs.
TEST(StaticRuntime, FmodScalar) {
  const auto fmod_scalar = R"JIT(
    def forward(self, a: Tensor, b: int):
        return torch.fmod(a, b).clone()
  )JIT";

  auto a = at::randn({2, 3});

  // fmod scalar version
  std::vector<IValue> args2{a, 3};
  testStaticRuntime(fmod_scalar, args2);

  // check for dynamic shapes
  auto c = at::randn({4, 3, 2});
  std::vector<IValue> args3{c, 4};
  testStaticRuntime(fmod_scalar, args2, args3);

  // test int32 version
  a = at::randint(-100, 100, {2, 3}, at::kInt);
  c = at::randint(-100, 100, {4, 3, 2}, at::kInt);
  testStaticRuntime(fmod_scalar, {a, 3});
  testStaticRuntime(fmod_scalar, {a, 3}, {c, 4});
}

// quantized::embedding_bag_byte_prepack on an 8x16 float input, re-run with a
// 16x32 input.
TEST(StaticRuntime, QEmbeddingBagBytePrepack) {
  const std::string embedding_bag_byte_prepack_script = R"IR(
    graph(%input: Tensor):
        %none : None = prim::Constant()
        %output: Tensor = quantized::embedding_bag_byte_prepack(%input)
        %res: Tensor = aten::clone(%output, %none)
        return (%res)
  )IR";

  auto a = torch::randn({8, 16}, at::ScalarType::Float);
  auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);

  testStaticRuntime(embedding_bag_byte_prepack_script, {a});
  testStaticRuntime(embedding_bag_byte_prepack_script, {a}, {b});
}

// Round trip through quantized::embedding_bag_byte_prepack followed by
// quantized::embedding_bag_byte_unpack, re-run with a larger input.
TEST(StaticRuntime, QEmbeddingBagByteUnpack) {
  const auto src = R"IR(
    graph(%input: Tensor):
        %none : None = prim::Constant()
        %weight: Tensor = quantized::embedding_bag_byte_prepack(%input)
        %output: Tensor = quantized::embedding_bag_byte_unpack(%weight)
        %res: Tensor = aten::clone(%output, %none)
        return (%res)
  )IR";

  auto a = torch::randn({8, 16}, at::ScalarType::Float);
  auto b = torch::randn({8 * 2, 16 * 2}, at::ScalarType::Float);

  testStaticRuntime(src, {a});
  testStaticRuntime(src, {a}, {b});
}

// torch.linalg_norm with an int `ord`, keepdim=True and an explicit dtype.
TEST(StaticRuntime, LinalgNorm_ScalarOrd) {
  const auto linalg_norm_ord_scalar = R"JIT(
    def forward(self, a: Tensor, ord: int, dim: List[int], keepdim: bool, dtype: int):
        return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto dim = std::vector<int64_t>({1});
  auto dtype = at::ScalarType::Float;

  std::vector<IValue> args0{a, 4, dim, true, dtype};
  testStaticRuntime(linalg_norm_ord_scalar, args0);

  auto b = at::randn({3, 2, 6});
  std::vector<IValue> args1{b, 4, dim, true, dtype};
  testStaticRuntime(linalg_norm_ord_scalar, args0, args1);
}

// torch.linalg_norm with the string `ord` "fro" over dims {0, 1},
// keepdim=True and an explicit dtype.
TEST(StaticRuntime, LinalgNorm_StringOrd) {
  const auto linalg_norm_ord_str = R"JIT(
    def forward(self, a: Tensor, ord: str, dim: List[int], keepdim: bool, dtype: int):
        return torch.linalg_norm(a, ord, dim, keepdim, dtype=dtype).clone()
  )JIT";

  auto a = at::randn({2, 3});
  auto dim = std::vector<int64_t>({0, 1});
  auto dtype = at::ScalarType::Float;

  std::vector<IValue> args0{a, "fro", dim, true, dtype};
  testStaticRuntime(linalg_norm_ord_str, args0);

  auto b = at::randn({3, 2, 17});
  std::vector<IValue> args1{b, "fro", dim, true, dtype};
  testStaticRuntime(linalg_norm_ord_str, args0, args1);
}

// torch.index_put with three ways of passing indices: a Tuple[Optional[Tensor]],
// a List[Tensor], and a List[Optional[Tensor]] built inside the script.
TEST(StaticRuntime, Index_Put) {
  const auto index_put_str = R"JIT(
    def forward(self, a: Tensor, indices: Tuple[Optional[Tensor]], values: Tensor, accumulate: bool):
        return torch.index_put(a, indices, values, accumulate).clone()
  )JIT";

  auto a = at::randn({2});
  auto indices_a = std::make_tuple(torch::tensor({0}, at::kLong));
  auto values_a = at::randn({1});

  std::vector<IValue> args0{a, indices_a, values_a, false};
  testStaticRuntime(index_put_str, args0);

  const auto index_put_non_optional_str = R"JIT(
    def forward(self, a: Tensor, indices: List[Tensor], values: Tensor, accumulate: bool):
        return torch.index_put(a, indices, values, accumulate).clone()
  )JIT";

  auto indices_b = c10::List<at::Tensor>{torch::tensor({0}, at::kLong)};
  std::vector<IValue> args1{a, indices_b, values_a, false};
  testStaticRuntime(index_put_non_optional_str, args1);

  const auto index_put_list_construct = R"JIT(
    def forward(self, a: Tensor, indices: Tensor, values: Tensor, accumulate: bool):
        indices: List[Optional[Tensor]] = [indices]
        return torch.index_put(a, indices, values, accumulate).clone()
  )JIT";

  std::vector<IValue> args2{a, torch::tensor({0}, at::kLong), values_a, false};
  testStaticRuntime(index_put_list_construct, args2);
}

// torch.item on a single-element tensor.
TEST(StaticRuntime, Item) {
  const auto item_str = R"JIT(
    def forward(self, a: Tensor):
        return torch.item(a)
  )JIT";

  auto a = at::randn({1});

  std::vector<IValue> args0{a};
  testStaticRuntime(item_str, args0);
}

// torch.tensor_split with an int section count, a tensor section count, and
// an explicit List[int] of split indices.
TEST(StaticRuntime, Tensor_Split) {
  const auto tensor_split_str1 = R"JIT(
    def forward(self, a: Tensor, sections: int, dim: int):
        return torch.tensor_split(a, sections, dim)
  )JIT";
  std::vector<IValue> args1{at::randn({8}), 3, 0};

  const auto tensor_split_str2 = R"JIT(
    def forward(self, a: Tensor, sections: Tensor, dim: int):
        return torch.tensor_split(a, sections, dim)
  )JIT";
  std::vector<IValue> args2{at::randn({8}), torch::tensor(3), 0};

  const auto tensor_split_str3 = R"JIT(
    def forward(self, a: Tensor, indices: List[int], dim: int):
        return torch.tensor_split(a, indices, dim)
  )JIT";
  std::vector<IValue> args3{at::randn({8}), c10::List<int64_t>({1, 6}), 0};

  testStaticRuntime(tensor_split_str1, args1);
  testStaticRuntime(tensor_split_str2, args2);
  testStaticRuntime(tensor_split_str3, args3);
}

// aten::cpu applied to the result of an aten::add (alpha = 0).
TEST(StaticRuntime, JIT_Aten_Cpu) {
  const std::string script = R"IR(
    graph(%a: Tensor):
        %1 : int = prim::Constant[value=0]()
        %aa: Tensor = aten::add(%a, %a, %1)
        %ret: Tensor = aten::cpu(%aa)
        return (%ret)
  )IR";

  // Builds a StaticModule directly from the parsed graph, as a smoke check
  // that construction succeeds.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);

  auto a = at::randn({2, 4});
  std::vector<IValue> args0{a};

  testStaticRuntime(script, args0);
}

// aten::numel applied to the result of an aten::add (alpha = 0).
TEST(StaticRuntime, JIT_Aten_Numel) {
  const std::string script = R"IR(
    graph(%a: Tensor):
        %1 : int = prim::Constant[value=0]()
        %aa: Tensor = aten::add(%a, %a, %1)
        %ret: int = aten::numel(%aa)
        return (%ret)
  )IR";

  // Builds a StaticModule directly from the parsed graph, as a smoke check
  // that construction succeeds.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);

  auto a = at::randn({2, 4});
  std::vector<IValue> args0{a};

  testStaticRuntime(script, args0);
}

// aten::list on a str, and on an int[] whose copy is then mutated to check
// that a deep copy was made.
TEST(StaticRuntime, JIT_Aten_List) {
  const auto script_str = R"IR(
    graph(%a: str):
        %ret: str[] = aten::list(%a)
        return (%ret)
  )IR";
  std::string a = "abcd";
  std::vector<IValue> args0{a};
  testStaticRuntime(script_str, args0);

  // Update the result of aten::list to ensure that a deep copy
  // took place
  const auto script_list = R"IR(
    graph(%a : int[]):
        %idx : int = prim::Constant[value=0]()
        %value : int = prim::Constant[value=42]()
        %res : int[] = aten::list(%a)
        %updated : int[] = aten::_set_item(%res, %idx, %value)
        return (%res, %a)
  )IR";

  std::vector<IValue> args1{c10::List<int64_t>{1, 2, 3}};
  testStaticRuntime(script_list, args1);
}

// aten::__range_length on int lo/hi/step arguments.
TEST(StaticRuntime, JIT_Aten_Range_Length) {
  const std::string script = R"IR(
    graph(%lo: int, %hi: int, %step: int):
        %1 : int = prim::Constant[value=0]()
        %ret: int = aten::__range_length(%lo, %hi, %step)
        return (%ret)
  )IR";

  // Builds a StaticModule directly from the parsed graph, as a smoke check
  // that construction succeeds.
  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);

  std::vector<IValue> args0{0, 10, 2};

  testStaticRuntime(script, args0);
}

// aten::cat on a list produced by aten::slice. The slice constants
// (start=0, end=1, step=1) keep only %a, so cat receives a one-element list.
// The test first checks that the aten::cat node is present in the
// StaticModule.
TEST(StaticRuntime, Cat) {
  const std::string cat_script = R"IR(
    graph(%a: Tensor, %b: Tensor, %dim: int):
        %ten_list: Tensor[] = prim::ListConstruct(%a, %b)
        %1 : int = prim::Constant[value=0]()
        %2 : int = prim::Constant[value=1]()
        %3 : int = prim::Constant[value=1]()
        %ten_list2 : Tensor[] = aten::slice(%ten_list, %1, %2, %3)
        %ret: Tensor = aten::cat(%ten_list2, %dim)
        return (%ret)
  )IR";

  auto graph = std::make_shared<Graph>();
  std::unordered_map<std::string, Value*> vmap;
  parseIR(cat_script, graph.get(), vmap);
  torch::jit::StaticModule smodule(graph);
  ASSERT_TRUE(getNodeWithKind(smodule, "aten::cat"));

  auto a = at::randn({2, 4});
  auto b = at::randn({3, 4});
  std::vector<IValue> args0{a, b, 0};

  testStaticRuntime(cat_script, args0);

  auto c = at::randn({3, 4});
  auto d = at::randn({3, 5});
  std::vector<IValue> args1{c, d, 1};
  testStaticRuntime(cat_script, args0, args1);

  std::vector<IValue> args_dim_negative{c, d, -1};
  testStaticRuntime(cat_script, args_dim_negative);
}

TEST(StaticRuntime, Cumsum) {
  // clone() keeps the cumsum result an intermediate instead of the model
  // output.
  const auto cumsum_script = R"JIT(
    def forward(self, a: Tensor, dim: int):
        return torch.cumsum(a, dim).clone()
  )JIT";

  std::vector<IValue> small_args{at::randn({2, 3}), 0};
  testStaticRuntime(cumsum_script, small_args);

  // Second run uses a larger tensor along a different dim.
  std::vector<IValue> large_args{at::randn({3, 6}), 1};
  testStaticRuntime(cumsum_script, small_args, large_args);
}
2721

2722
TEST(StaticRuntime, CumsumDtype) {
  // cumsum with an explicit dtype argument, passed in as an int.
  const auto cumsum_script_dtype = R"JIT(
    def forward(self, a: Tensor, dim: int, dtype: int):
        return torch.cumsum(a, dim, dtype=dtype).clone()
  )JIT";

  const auto float_dtype = at::ScalarType::Float;

  std::vector<IValue> first_args{at::randn({1, 2}), 0, float_dtype};
  testStaticRuntime(cumsum_script_dtype, first_args);

  std::vector<IValue> second_args{at::randn({3, 6}), 1, float_dtype};
  testStaticRuntime(cumsum_script_dtype, first_args, second_args);
}
2737

2738
TEST(StaticRuntime, Nonzero) {
  const auto nonzero_tensor = R"JIT(
    def forward(self, input: Tensor):
        a = torch.nonzero(input).clone()
        return (a)
  )JIT";

  // Inputs hold random 0/1 values, so the number of nonzeros varies.
  const auto matrix = at::randint(0, 2, {2, 3});
  testStaticRuntime(nonzero_tensor, {matrix});

  const auto cube = at::randint(0, 2, {4, 3, 2});
  testStaticRuntime(nonzero_tensor, {matrix}, {cube});
}
2751

2752
TEST(StaticRuntime, SignedLog1p) {
  // Computes sign(x) * log1p(abs(x)) and clones the product.
  const std::string signed_log1p_script = R"IR(
    graph(%input):
        %0 : Tensor = aten::sign(%input)
        %1 : Tensor = aten::abs(%input)
        %2 : Tensor = aten::log1p(%1)
        %3 : Tensor = aten::mul(%0, %2)
        %none : NoneType = prim::Constant()
        %res : Tensor = aten::clone(%3, %none)
        return (%res)
  )IR";

  std::vector<IValue> small_input = {at::randn({2, 2})};
  testStaticRuntime(
      signed_log1p_script, small_input, /*args2=*/{}, /*use_allclose=*/true);

  std::vector<IValue> large_input = {at::randn({3, 3, 3})};
  testStaticRuntime(
      signed_log1p_script, small_input, large_input, /*use_allclose=*/true);
}
2770

2771
TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithImmutableInputDict) {
  // `input` is only read, so its lookups should be replaced by a single
  // static_runtime::dict_unpack node.
  const auto getitem_immutable_input_dict_script = R"JIT(
    def forward(self, input: Dict[int, Tensor]):
        a = input[0]
        b = input[1]
        c = a + b
        return c.clone()
  )JIT";

  script::Module module("module");
  module.define(getitem_immutable_input_dict_script);
  torch::jit::StaticModule smodule(module);
  EXPECT_FALSE(hasNodeWithKind(smodule, "aten::__getitem__"));
  EXPECT_TRUE(hasNodeWithKind(smodule, "static_runtime::dict_unpack"));

  auto a = at::randn({2, 4});
  auto b = at::randn({2, 4});
  c10::Dict<c10::IValue, c10::IValue> dict(
      c10::IntType::get(), c10::TensorType::get());
  dict.insert(0, a);
  dict.insert(1, b);
  testStaticRuntime(getitem_immutable_input_dict_script, {dict});

  // The second dict is passed as args2 so the memory planner has to grow its
  // buffers for the intermediate `c`. Its tensors must exceed one 64-byte
  // aligned chunk; the old 3x4 float tensors (48 bytes) did not, and they were
  // only run standalone anyway.
  c10::Dict<c10::IValue, c10::IValue> dict0(
      c10::IntType::get(), c10::TensorType::get());
  auto a0 = at::randn({8, 10});
  auto b0 = at::randn({8, 10});
  dict0.insert(0, a0);
  dict0.insert(1, b0);
  testStaticRuntime(getitem_immutable_input_dict_script, {dict}, {dict0});
}
2802

2803
TEST(StaticRuntime, RemoveImmutableInputDictLookupsWithMutableInputDict) {
  // `input` is written to (input[1] = a), so its __getitem__ lookups must be
  // kept rather than replaced by static_runtime::dict_unpack.
  const auto getitem_mutable_input_dict_script = R"JIT(
    def forward(self, input: Dict[int, Tensor]):
        a = input[0]
        input[1] = a
        b = input[1]
        c = a + b
        return c.clone()
  )JIT";

  script::Module mod("module");
  mod.define(getitem_mutable_input_dict_script);
  torch::jit::StaticModule static_mod(mod);
  EXPECT_TRUE(hasNodeWithKind(static_mod, "aten::__getitem__"));
  EXPECT_FALSE(hasNodeWithKind(static_mod, "static_runtime::dict_unpack"));
}
2819

2820
TEST(StaticRuntime, VarTupleUnpack) {
  // Two consecutive tuple unpacks should be replaced by one
  // static_runtime::VarTupleUnpack node.
  const auto var_tuple_unpack_script = R"JIT(
    def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
        a, b = input_0
        c, d = input_1
        res = a * c + b * d
        return res.clone()
  )JIT";

  script::Module mod("module");
  mod.define(var_tuple_unpack_script);
  torch::jit::StaticModule static_mod(mod);
  EXPECT_FALSE(hasNodeWithKind(static_mod, "prim::TupleUnpack"));
  EXPECT_TRUE(hasNodeWithKind(static_mod, "static_runtime::VarTupleUnpack"));

  const auto small = at::randn({2, 2});
  const auto large = at::randn({3, 3, 3});
  std::vector<IValue> small_args{
      c10::ivalue::Tuple::create(small, small),
      c10::ivalue::Tuple::create(1, 2)};
  std::vector<IValue> large_args{
      c10::ivalue::Tuple::create(large, large),
      c10::ivalue::Tuple::create(1, 2)};

  testStaticRuntime(var_tuple_unpack_script, small_args);
  testStaticRuntime(var_tuple_unpack_script, small_args, large_args);
}
2845

2846
TEST(StaticRuntime, VarTupleUnpack_NotApplied) {
  const auto var_tuple_unpack_not_applied_script = R"JIT(
    def forward(self, input_0: Tuple[Tensor, Tensor], input_1: Tuple[int, int]):
        a, b = input_0
        x = a + b
        c, d = input_1
        res = a * c + b * d + x
        return res.clone()
  )JIT";

  // The `x = a + b` computation sits between the two unpacks, so they must
  // not be merged into static_runtime::VarTupleUnpack.
  script::Module mod("module");
  mod.define(var_tuple_unpack_not_applied_script);
  torch::jit::StaticModule static_mod(mod);
  EXPECT_FALSE(hasNodeWithKind(static_mod, "static_runtime::VarTupleUnpack"));
  EXPECT_TRUE(hasNodeWithKind(static_mod, "prim::TupleUnpack"));
}
2864

2865
TEST(StaticRuntime, RemainderTensor) {
  const auto remainder_tensor = R"JIT(
    def forward(self, x, y):
        return torch.remainder(x, y).clone()
  )JIT";

  // Divisors are drawn from [1, 10), so they are never zero.
  std::vector<IValue> small_args = {
      at::randint(0, 10, {2, 2}), at::randint(1, 10, {2, 2})};
  std::vector<IValue> large_args = {
      at::randint(0, 10, {3, 6}), at::randint(1, 10, {3, 6})};

  // Compare with allclose and equal_nan in case any output is NaN.
  testStaticRuntime(
      remainder_tensor,
      small_args,
      /*args2=*/{},
      /*use_allclose=*/true,
      /*use_equalnan=*/true);
  testStaticRuntime(
      remainder_tensor,
      small_args,
      large_args,
      /*use_allclose=*/true,
      /*use_equalnan=*/true);
}
2890

2891
TEST(StaticRuntime, RemainderScalar) {
  const auto remainder_scalar = R"JIT(
    def forward(self, x, y: int):
        return torch.remainder(x, y).clone()
  )JIT";

  // The scalar divisor is a fixed, nonzero 4.
  std::vector<IValue> small_args = {at::randint(0, 10, {2, 2}), 4};
  std::vector<IValue> large_args = {at::randint(0, 10, {3, 6}), 4};

  // Compare with allclose and equal_nan in case any output is NaN.
  testStaticRuntime(
      remainder_scalar,
      small_args,
      /*args2=*/{},
      /*use_allclose=*/true,
      /*use_equalnan=*/true);
  testStaticRuntime(
      remainder_scalar,
      small_args,
      large_args,
      /*use_allclose=*/true,
      /*use_equalnan=*/true);
}
2914

2915
TEST(StaticRuntime, Where) {
  // Picks x where x > 0, else y.
  const auto where_script = R"JIT(
    def forward(self, x, y):
        return torch.where(x > 0, x, y).clone()
  )JIT";

  std::vector<IValue> small_args = {at::randn({2, 2}), at::randn({2, 2})};
  std::vector<IValue> large_args = {at::randn({8, 10}), at::randn({8, 10})};

  testStaticRuntime(where_script, small_args);
  testStaticRuntime(where_script, small_args, large_args);
}
2927

2928
TEST(StaticRuntime, WhereBroadcast) {
  // A 1-D condition is viewed as [-1, 1, ..., 1] so it broadcasts against the
  // remaining dims of x and y.
  const auto where_script = R"JIT(
    def forward(self, cond_1d, x, y):
        shape = [-1] + [1] * (x.dim() - 1)
        cond = cond_1d.view(shape)
        return torch.where(cond, x, y).clone()
  )JIT";

  std::vector<IValue> small_args = {
      at::tensor({0, 1}).to(at::kBool), at::randn({2, 2}), at::randn({2, 2})};
  std::vector<IValue> large_args = {
      at::tensor({1, 0, 0}).to(at::kBool), at::randn({3, 6}), at::randn({3, 6})};

  testStaticRuntime(where_script, small_args);
  testStaticRuntime(where_script, small_args, large_args);
}
2946

2947
TEST(StaticRuntime, View) {
  // aten::view is not an out variant, so the clone is not strictly required;
  // it avoids testStaticRuntime's warning about graphs with a single op.
  const auto src = R"IR(
    graph(%input : Tensor, %shape : int[]):
        %none : NoneType = prim::Constant()
        %view : Tensor = aten::view(%input, %shape)
        %res : Tensor = aten::clone(%view, %none)
        return (%res)
  )IR";

  // Flatten a 2x2 tensor to shape [4], then an 2x2x2 tensor to shape [4, 2].
  std::vector<IValue> flatten_args{at::randn({2, 2}), c10::List<int64_t>({4})};
  std::vector<IValue> reshape_args{
      at::randn({2, 2, 2}), c10::List<int64_t>({4, 2})};

  testStaticRuntime(src, flatten_args);
  testStaticRuntime(src, flatten_args, reshape_args);
}
2965

2966
TEST(StaticRuntime, Size) {
  // x.size(dim) returns an int; x.size() returns the whole shape.
  const auto src_with_dim = R"JIT(
      def forward(self, x, dim: int):
          return x.size(dim)
  )JIT";

  const auto src_no_dim = R"JIT(
      def forward(self, x):
          return x.size()
  )JIT";

  std::vector<IValue> dim_zero{at::randn({1}), 0};
  std::vector<IValue> dim_negative{at::randn({1}), -1};
  std::vector<IValue> dim_one{at::randn({2, 4}), 1};
  std::vector<IValue> whole_shape{at::randn({2, 4})};

  testStaticRuntime(src_with_dim, dim_zero);
  testStaticRuntime(src_with_dim, dim_negative);
  testStaticRuntime(src_with_dim, dim_zero, dim_one);
  testStaticRuntime(src_no_dim, whole_shape);
}
2987

2988
TEST(StaticRuntime, Squeeze) {
  // squeeze is a native op rather than an out variant; the clone is only
  // there to silence testStaticRuntime's warnings.
  const auto src = R"JIT(
    def forward(self, inp, dim: int):
        return inp.squeeze(dim).clone()
  )JIT";

  const auto matrix = at::randn({2, 2});
  const auto cube = at::randn({3, 2, 3});

  testStaticRuntime(src, {matrix, 0});
  testStaticRuntime(src, {matrix, 1});
  testStaticRuntime(src, {matrix, -1}, {cube, 2});
}
3003

3004
TEST(StaticRuntime, NumToTensorScalar) {
  // prim::NumToTensor on an int input, followed by a clone.
  const auto num_to_tensor_ir = R"IR(
    graph(%1 : int):
      %2 : NoneType = prim::Constant()
      %3 : Tensor = prim::NumToTensor(%1)
      %4 : Tensor = aten::clone(%3, %2)
      return (%4)
  )IR";

  std::vector<IValue> inputs{IValue(5)};
  testStaticRuntime(num_to_tensor_ir, inputs);
}
3017

3018
TEST(StaticRuntime, NumToTensorFalse) {
  // prim::NumToTensor on a bool input (false), followed by a clone.
  const auto num_to_tensor_ir = R"IR(
    graph(%1 : bool):
      %2 : NoneType = prim::Constant()
      %3 : Tensor = prim::NumToTensor(%1)
      %4 : Tensor = aten::clone(%3, %2)
      return (%4)
  )IR";

  std::vector<IValue> inputs{IValue(false)};
  testStaticRuntime(num_to_tensor_ir, inputs);
}
3031

3032
TEST(StaticRuntime, NumToTensorTrue) {
  // prim::NumToTensor on a bool input (true), followed by a clone.
  const auto num_to_tensor_ir = R"IR(
    graph(%1 : bool):
      %2 : NoneType = prim::Constant()
      %3 : Tensor = prim::NumToTensor(%1)
      %4 : Tensor = aten::clone(%3, %2)
      return (%4)
  )IR";

  std::vector<IValue> inputs{IValue(true)};
  testStaticRuntime(num_to_tensor_ir, inputs);
}
3045

3046
TEST(StaticRuntime, Split) {
  // Split with a single int chunk size; the result is a list of tensors.
  const auto src = R"JIT(
    def forward(self, inp, split_size: int, dim: int):
        return inp.split(split_size, dim)
  )JIT";

  const auto matrix = at::randn({2, 2});
  const auto cube = at::randn({2, 2, 2});

  testStaticRuntime(src, {matrix, 1, 0});
  testStaticRuntime(src, {matrix, 1, 1});
  testStaticRuntime(src, {matrix, 2, -1}, {cube, 2, 2});
}
3059

3060
TEST(StaticRuntime, SplitWithSizes) {
  // Split with an explicit list of chunk sizes.
  const auto src = R"JIT(
    def forward(self, inp, split_sizes: List[int], dim: int):
        return inp.split(split_sizes, dim)
  )JIT";

  const auto matrix = at::randn({2, 2});
  const auto cube = at::randn({2, 2, 2});
  const auto halves = c10::List<int64_t>{1, 1};

  testStaticRuntime(src, {matrix, halves, 0});
  testStaticRuntime(src, {matrix, halves, 1});
  testStaticRuntime(src, {matrix, halves, -1}, {cube, halves, 2});
}
3074

3075
namespace {
3076

3077
void maybe_throw(bool should_throw) {
3078
  if (should_throw) {
3079
    throw std::runtime_error("test exception");
3080
  }
3081
}
3082

3083
TORCH_LIBRARY(static_runtime_tests, m) {
3084
  // Conservative so this op doesn't get deleted by dead
3085
  // code elimination
3086
  m.def(torch::schema(
3087
      "static_runtime_tests::maybe_throw(bool throw) -> ()",
3088
      at::AliasAnalysisKind::CONSERVATIVE));
3089
  m.impl("maybe_throw", maybe_throw);
3090
}
3091

3092
} // namespace
3093

3094
TEST(StaticRuntime, ModelCrashOnFirstRun) {
  const auto src = R"JIT(
    graph(%0: Tensor, %throw: bool):
        %1: Tensor = aten::mul(%0, %0)
        static_runtime_tests::maybe_throw(%throw)
        %2: Tensor = aten::mul(%1, %1)
        %3: Tensor = aten::mul(%2, %2)
        return (%3)
  )JIT";

  auto graph = getGraphFromIR(src);
  StaticModule static_module(graph);
  auto& runtime = static_module.runtime();

  std::vector<IValue> crashing_args{at::randn({1}), true};
  std::vector<IValue> passing_args{at::randn({1}), false};

  // The very first run throws from maybe_throw.
  EXPECT_THROW(runtime(crashing_args, {}), std::runtime_error);

  // That run never finished, so no memory planner was created, and nothing
  // may have leaked.
  EXPECT_EQ(runtime.get_memory_planner(), nullptr);
  runtime.check_for_memory_leak();

  // The runtime must still be usable: a clean run matches the JIT and
  // creates the memory planner.
  compareResultsWithJIT(runtime, graph, passing_args);
  EXPECT_NE(runtime.get_memory_planner(), nullptr);
}
3121

3122
TEST(StaticRuntime, ModelCrashOnSecondRun) {
  const auto src = R"JIT(
    graph(%0: Tensor, %throw: bool):
        %1: Tensor = aten::mul(%0, %0)
        static_runtime_tests::maybe_throw(%throw)
        %2: Tensor = aten::mul(%1, %1)
        %3: Tensor = aten::mul(%2, %2)
        return (%3)
  )JIT";

  auto graph = getGraphFromIR(src);
  StaticModule static_module(graph);
  auto& runtime = static_module.runtime();

  std::vector<IValue> crashing_args{at::randn({1}), true};
  std::vector<IValue> passing_args{at::randn({1}), false};

  // A successful first run creates the memory planner.
  runtime(passing_args, {});
  EXPECT_NE(runtime.get_memory_planner(), nullptr);
  runtime.check_for_memory_leak();

  // The second run throws partway through; it must not leak.
  EXPECT_THROW(runtime(crashing_args, {}), std::runtime_error);
  runtime.check_for_memory_leak();

  // The runtime must still produce JIT-equivalent results afterwards.
  compareResultsWithJIT(runtime, graph, passing_args);
}
3149

3150
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrows) {
  // The graph output comes from static_runtime::select_tensor, and
  // maybe_throw(true) then fails the run before it returns.
  const auto src = R"JIT(
    graph(%0: Tensor):
        %1: Tensor = aten::mul(%0, %0)
        %2: Tensor = aten::mul(%1, %1)
        %3: bool = prim::Constant[value=1]()
        %4: Tensor = static_runtime::select_tensor(%1, %2, %3)
        static_runtime_tests::maybe_throw(%3)
        return (%4)
  )JIT";
  auto graph = getGraphFromIR(src);
  StaticModule static_module(graph);
  auto& runtime = static_module.runtime();

  std::vector<IValue> inputs{at::randn({1})};
  EXPECT_THROW(runtime(inputs), std::runtime_error);
}
3167

3168
TEST(StaticRuntime, ModelCrashOnFirstRunWithBorrowedInputs) {
  // select_tensor picks directly between the graph inputs, and the run then
  // throws.
  const auto src = R"JIT(
    graph(%0: Tensor, %1: Tensor):
        %2: bool = prim::Constant[value=1]()
        %3: Tensor = static_runtime::select_tensor(%0, %1, %2)
        static_runtime_tests::maybe_throw(%2)
        return (%3)
  )JIT";
  auto graph = getGraphFromIR(src);
  StaticModule static_module(graph);
  auto& runtime = static_module.runtime();

  // Inputs are handed over by rvalue to exercise the move overload.
  std::vector<IValue> inputs{at::randn({1}), at::randn({1})};
  EXPECT_THROW(runtime(std::move(inputs)), std::runtime_error);
}
3183

3184
TEST(StaticRuntime, ReplaceWithMaybeCopy) {
  // aten::to with a constant dtype of 4 (presumably at::kLong -- confirm) and
  // both bool flags false.
  const std::string to = R"IR(
    graph(%0 : Tensor):
      %1: int = prim::Constant[value=4]()
      %2: bool = prim::Constant[value=0]()
      %3: None = prim::Constant()
      %res : Tensor = aten::to(%0, %1, %2, %2, %3)
      return (%res)
  )IR";

  at::Tensor a = at::tensor({1.1, 2.2, 3.3, 4.0}, at::ScalarType::Float);
  std::vector<IValue> args{a};
  auto g = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(to, g.get());

  // Reference result from the JIT interpreter.
  Stack stack(args);
  torch::jit::GraphExecutor graph_exec(g, "");
  graph_exec.run(stack);
  // stack.size() is unsigned; compare against an unsigned literal to avoid a
  // signed/unsigned comparison.
  ASSERT_EQ(stack.size(), 1u);
  auto expected = stack[0].toTensor();

  // Static Runtime.
  torch::jit::StaticModule smodule(g);
  auto actual = smodule(args, {}).toTensor();
  smodule.runtime().check_for_memory_leak();

  EXPECT_TRUE(expected.equal(actual));

  // Make a fresh graph to ensure the pass works in isolation.
  auto new_graph = std::make_shared<torch::jit::Graph>();
  torch::jit::parseIR(to, new_graph.get());
  ReplaceWithMaybeCopy(new_graph);
  EXPECT_FALSE(hasNodeWithKind(new_graph, "aten::to"));
  EXPECT_TRUE(
      hasNodeWithKind(new_graph, "static_runtime::to_maybe_copy_out"));
}
3221

3222
TEST(StaticRuntime, Int) {
  // int() on a single-element tensor, applied twice and summed.
  const auto src = R"JIT(
    def forward(self, x):
        return int(x) + int(x)
  )JIT";
  std::vector<IValue> inputs{at::tensor({3.14})};
  testStaticRuntime(src, inputs);
}
3230

3231
TEST(StaticRuntime, ReturnConstant) {
  // A model with no inputs that just returns a constant.
  const auto src = R"JIT(
    def forward(self):
        return 1
  )JIT";

  const std::vector<IValue> no_inputs;
  testStaticRuntime(src, no_inputs);
}
3239

3240
TEST(StaticRuntime, SimpleIf) {
  // Each branch returns its own cloned result.
  const auto src = R"JIT(
    def forward(self, cond: bool, x):
        if cond:
            return torch.mul(x, 42).clone()
        else:
            return x.clone()
  )JIT";

  std::vector<IValue> else_branch{false, at::randn({1})};
  std::vector<IValue> then_branch{true, at::randn({1})};
  std::vector<IValue> then_branch_large{true, at::randn({3, 3, 3})};

  testStaticRuntime(src, else_branch);
  testStaticRuntime(src, then_branch);
  testStaticRuntime(src, then_branch, then_branch_large);
}
3257

3258
TEST(StaticRuntime, NestedIf) {
  const auto src = R"JIT(
    def forward(self, cond1: bool, cond2: bool, x):
        y = x * 42
        if cond1:
            y = y * y
            if cond2:
                y += x
        else:
            if cond2:
                return x.clone()

        return y.clone()
  )JIT";

  // Exercise every combination of the two conditions.
  for (const bool outer : {true, false}) {
    for (const bool inner : {true, false}) {
      std::vector<IValue> small_args{outer, inner, at::randn({1})};
      std::vector<IValue> large_args{outer, inner, at::randn({3, 3, 3})};
      testStaticRuntime(src, small_args, large_args);
    }
  }
}
3281

3282
TEST(StaticRuntime, DeeplyNestedIf) {
  const auto src = R"JIT(
    def forward(self, cond1: bool, cond2: bool, cond3: bool, x):
        y = x * 42
        if cond1:
            y = y * y
            if cond2:
                y += x

            if cond2 and cond3:
                y += 1

            if cond2:
                if cond3:
                    y += 2
                else:
                    y = y * y
                    y += 4
        else:
            if cond2:
                return x.clone()
            if cond3 or cond2:
                y += 42

        return y.clone()
  )JIT";

  // Exercise all eight combinations of the three conditions.
  for (const bool first : {true, false}) {
    for (const bool second : {true, false}) {
      for (const bool third : {true, false}) {
        std::vector<IValue> small_args{first, second, third, at::randn({1})};
        std::vector<IValue> large_args{
            first, second, third, at::randn({3, 3, 3})};
        testStaticRuntime(src, small_args, large_args);
      }
    }
  }
}
3319

3320
TEST(StaticRuntime, BasicForLoop) {
  // Adds 1 to a clone of x, loop_max times, using a for loop.
  const auto src = R"JIT(
    def forward(self, x, loop_max: int):
        y = x.clone()
        for i in range(loop_max):
            y += 1
        return y
  )JIT";

  std::vector<IValue> small_args{at::randn({1}), 10};
  std::vector<IValue> large_args{at::randn({3, 3, 3}), 10};

  testStaticRuntime(src, small_args, large_args);
}
3334

3335
TEST(StaticRuntime, BasicWhileLoop) {
  // Same computation as BasicForLoop, written with a while loop.
  const auto src = R"JIT(
    def forward(self, x, loop_max: int):
        y = x.clone()
        loop_count = 0
        while loop_count < loop_max:
            y += 1
            loop_count += 1
        return y
  )JIT";

  std::vector<IValue> small_args{at::randn({1}), 10};
  std::vector<IValue> large_args{at::randn({3, 3, 3}), 10};

  testStaticRuntime(src, small_args, large_args);
}
3351

3352
TEST(StaticRuntime, NestedLoops) {
  // Nested for loops with list appends inside a branch; returns the tensor
  // and both lists.
  const auto src = R"JIT(
    def forward(self, x, loop_max: int):
        y = x.clone()
        even: List[int] = []
        odd: List[int] = []

        for i in range(loop_max):
            if i % 2:
                odd.append(i)
            else:
                even.append(i)

            for j in range(i):
                y += 1

        return y, even, odd
  )JIT";

  std::vector<IValue> small_args{at::randn({1}), 10};
  std::vector<IValue> large_args{at::randn({3, 3, 3}), 10};

  testStaticRuntime(src, small_args, large_args);
}
3376

3377
TEST(StaticRuntime, TupleIndex) {
  const auto src = R"JIT(
    def forward(self, idx: int, tup: Tuple[int, int]):
        a = tup[idx]
        return a * a
  )JIT";
  const auto tup_value = c10::ivalue::Tuple::create({1, 2});
  // Both a positive and a negative index.
  testStaticRuntime(src, {1, tup_value}, {-1, tup_value});

  // An index past the end must surface as std::out_of_range.
  torch::jit::Module mod("module");
  mod.define(src);
  StaticModule static_mod(mod);
  EXPECT_THROW(static_mod({100, tup_value}), std::out_of_range);
}
3391

3392
TEST(StaticRuntime, RaiseException) {
  // prim::RaiseException must surface as std::runtime_error whose what()
  // is exactly the message string passed in.
  const auto src = R"IR(
    graph(%str: str):
        %none: NoneType = prim::Constant()
        prim::RaiseException(%str, %none)
        return (%none)
  )IR";
  auto graph = getGraphFromIR(src);
  StaticModule static_mod(graph);
  const auto msg = "exception message";

  // Checks the message, then rethrows so EXPECT_THROW sees the exception.
  const auto run_and_check_message = [&]() {
    try {
      static_mod({msg});
    } catch (const std::runtime_error& e) {
      EXPECT_STREQ(msg, e.what());
      throw;
    }
  };
  EXPECT_THROW(run_and_check_message(), std::runtime_error);
}
3413

3414
TEST(StaticRuntime, Uninitialized) {
  const auto src = R"IR(
    graph():
      %0: int = prim::Uninitialized()
      return (%0)
  )IR";
  auto graph = getGraphFromIR(src);
  StaticModule static_mod(graph);
  const auto result = static_mod({});
  // Two uninitialized values never compare equal, so the output value itself
  // cannot be checked; only its type (Any) is.
  EXPECT_EQ(result.type()->kind(), c10::TypeKind::AnyType);
}
3427

3428
TEST(StaticRuntime, Format) {
  // str.format with int, Tensor and str arguments, followed by a full slice.
  const auto src = R"JIT(
    def forward(self, arg1: int, arg2: Tensor, arg3: str):
        a = "arg1: {}, arg2: {}, arg3: {}".format(arg1, arg2, arg3)
        return a[::]
  )JIT";
  std::vector<IValue> inputs{1, at::randn({3}), "str"};
  testStaticRuntime(src, inputs);
}
3436

3437
TEST(StaticRuntime, Device) {
  // Reads the same tensor's device twice.
  const auto src = R"JIT(
    def forward(self, x):
        return x.device, x.device
  )JIT";
  std::vector<IValue> inputs{at::tensor({1})};
  testStaticRuntime(src, inputs);
}
3444

3445
TEST(StaticRuntime, Dtype) {
  // Reads the dtypes of a long tensor and a float tensor.
  const auto src = R"JIT(
    def forward(self, x, y):
        return x.dtype, y.dtype
  )JIT";
  const auto long_tensor = at::tensor({1}, at::kLong);
  const auto float_tensor = at::tensor({1}, at::kFloat);
  testStaticRuntime(src, {long_tensor, float_tensor});
}
3453

3454
TEST(StaticRuntime, Dim) {
  // Reads the rank of a 2-D and a 1-D tensor.
  const auto src = R"JIT(
    def forward(self, x, y):
        return x.dim(), y.dim()
  )JIT";
  std::vector<IValue> inputs{at::randn({2, 2}), at::randn({1})};
  testStaticRuntime(src, inputs);
}
3461

3462
TEST(StaticRuntime, Not) {
  // Logical negation of both a true and a false input.
  const auto src = R"JIT(
    def forward(self, x: bool, y: bool):
        return not x, not y
  )JIT";
  std::vector<IValue> inputs{true, false};
  testStaticRuntime(src, inputs);
}
3469

3470
TEST(StaticRuntime, Bool) {
  // bool() applied to a Tensor, an int and a float.
  const auto src = R"JIT(
      def forward(self, x: Tensor, y: int, z: float):
          return bool(x), bool(y), bool(z)
  )JIT";
  std::vector<IValue> first_args{at::randn({1}), 0, 1.151};
  std::vector<IValue> second_args{at::zeros({1}), 1, 0.0};
  testStaticRuntime(src, first_args, second_args);
}
3477

3478
TEST(StaticRuntime, IsCuda) {
  // Reads is_cuda on two tensors created by at::randn.
  const auto src = R"JIT(
      def forward(self, x: Tensor, y: Tensor):
          return x.is_cuda, y.is_cuda
  )JIT";
  std::vector<IValue> inputs{at::randn({1}), at::randn({1})};
  testStaticRuntime(src, inputs);
}
3485

3486
TEST(StaticRuntime, ToList) {
  // prim::tolist using the tensor's own rank; the output is typed float[].
  const auto src = R"JIT(
      graph(%x: Tensor):
          %type: int = prim::Constant[value=1]()
          %dim: int = aten::dim(%x)
          %ret: float[] = prim::tolist(%x, %dim, %type)
          return (%ret)
  )JIT";
  std::vector<IValue> inputs{at::randn({2, 2})};
  testStaticRuntime(src, inputs);
}
3496

3497
TEST(StaticRuntime, IfThenElse) {
  // prim::IfThenElse selects %a or %b by %cond; the result is cloned.
  const auto src = R"IR(
    graph(%cond: bool, %a: Tensor, %b: Tensor):
        %none: NoneType = prim::Constant()
        %c: Tensor = prim::IfThenElse(%cond, %a, %b)
        %d: Tensor = aten::clone(%c, %none)
        return (%d)
  )IR";

  std::vector<IValue> pick_first{true, at::randn({1}), at::randn({1})};
  std::vector<IValue> pick_second{false, at::randn({1}), at::randn({1})};

  testStaticRuntime(src, pick_first);
  testStaticRuntime(src, pick_second);
}
3512

3513
TEST(StaticRuntime, EmptyIfBlock) {
  // The if has no else branch; with cond false the list stays empty.
  const auto src =
      R"JIT(
      def forward(self, cond: bool, a: Tensor, b: Tensor):
          l = []
          if cond:
              l.append((a + b).clone())
          return l
  )JIT";

  for (const bool cond : {true, false}) {
    testStaticRuntime(src, {cond, at::rand(1), at::rand({1, 2})});
  }
}
3526

3527
TEST(StaticRuntime, EmptyNestedIfBlock) {
  // Like EmptyIfBlock, but the append sits inside a second, nested if.
  const auto src =
      R"JIT(
      def forward(self, cond: bool, a: Tensor, b: Tensor):
          l = []
          if cond:
              if cond:
                  l.append((a + b).clone())
          return l
  )JIT";

  for (const bool cond : {true, false}) {
    testStaticRuntime(src, {cond, at::rand(1), at::rand({1, 2})});
  }
}
3541

3542
TEST(StaticRuntime, StackEmpty) {
  // torch.stack on an empty list must raise a c10::Error.
  const auto src = R"JIT(
    def forward(self):
        x = torch.stack([])
        return x
  )JIT";

  torch::jit::Module module("mod");
  module.define(src);

  torch::jit::StaticModule static_mod(module);
  EXPECT_THROW(static_mod({}), c10::Error);
}
3555

3556
TEST(StaticRuntime, ConcatEmpty) {
  // torch.concat on an empty list must raise a c10::Error.
  const auto src = R"JIT(
    def forward(self):
        x = torch.concat([])
        return x
  )JIT";

  torch::jit::Module module("mod");
  module.define(src);

  torch::jit::StaticModule static_mod(module);
  EXPECT_THROW(static_mod({}), c10::Error);
}
3569

3570
TEST(StaticRuntime, IntImplicit) {
  // aten::IntImplicit on a 0-D int tensor.
  const auto src = R"IR(
    graph(%a: Tensor):
        %y: int = aten::IntImplicit(%a)
        return (%y)
  )IR";
  const auto scalar_int_tensor = at::tensor({1}, at::kInt).squeeze();
  testStaticRuntime(src, {scalar_int_tensor});
}
3578

3579
TEST(StaticRuntime, IntImplicit_ThrowOnBadInputs) {
  const auto src = R"IR(
    graph(%a: Tensor):
        %y: int = aten::IntImplicit(%a)
        return (%y)
  )IR";
  auto graph = getGraphFromIR(src);
  torch::jit::StaticModule static_mod(graph);

  // Rejects a tensor that is not 0-D.
  const auto vector_input = at::tensor({1, 2}, at::kInt);
  EXPECT_THROW(static_mod({vector_input}), std::runtime_error);

  // Rejects a 0-D tensor of the wrong dtype.
  const auto float_scalar = at::tensor({1}, at::kFloat).squeeze();
  EXPECT_THROW(static_mod({float_scalar}), std::runtime_error);
}
3593

3594
TEST(StaticRuntime, Select) {
  // aten::select(%a, %dim, %index) followed by a clone.
  const auto src = R"IR(
    graph(%a: Tensor, %dim: int, %index: int):
        %none: NoneType = prim::Constant()
        %b: Tensor = aten::select(%a, %dim, %index)
        %c: Tensor = aten::clone(%b, %none)
        return (%c)
  )IR";
  std::vector<IValue> inputs{at::randn({2, 2}), 0, 1};
  testStaticRuntime(src, inputs);
}
3604

3605
TEST(StaticRuntime, ReshapeAs) {
  // Reshapes a 2x2 tensor to match a 4-element 1-D tensor.
  const auto src = R"JIT(
    def forward(self, a, b):
        return a.reshape_as(b).clone()
  )JIT";
  std::vector<IValue> inputs{at::randn({2, 2}), at::randn({4})};
  testStaticRuntime(src, inputs);
}
3612

3613
TEST(StaticRuntime, MoveCtor) {
  // A StaticRuntime that was moved-from to build a new runtime: the new one
  // must produce the same outputs as the original did.
  auto model = getDeepAndWideSciptModel();
  std::vector<IValue> inputs{
      at::randn({1, 1, 32}), at::randn({1, 1, 32}), at::randn({1, 50})};

  torch::jit::StaticModule static_mod(model);

  torch::jit::StaticRuntime original_runtime(static_mod);
  auto before_move = original_runtime(inputs);

  torch::jit::StaticRuntime moved_runtime(std::move(original_runtime));
  auto after_move = moved_runtime(inputs);
  compareResults(before_move, after_move);
}
3627

3628
TEST(StaticRuntime, SingleBlockIfReturnList) {
  // A list is built across a single if block and returned.
  const auto src = R"JIT(
    def forward(self, a, b, cond: bool):
        lst = []
        if cond:
            lst.append(a + b)
        return lst
  )JIT";
  std::vector<IValue> append_args{at::randn({1}), at::randn({1}), true};
  std::vector<IValue> skip_args{at::randn({42, 42}), at::randn({42, 42}), false};
  testStaticRuntime(src, append_args, skip_args);
}
3640

3641
TEST(StaticRuntime, NestedBlockIfReturnList) {
  // A list is built inside nested if blocks and returned from the outer one.
  const auto src = R"JIT(
    def forward(self, a, b, cond1: bool, cond2: bool):
        if cond1:
            lst = []
            if cond2:
                lst.append(a + b)
            lst.append(a * b)
            return lst
        return []
  )JIT";
  std::vector<IValue> both_true{at::randn({1}), at::randn({1}), true, true};
  std::vector<IValue> inner_false{
      at::randn({42, 42}), at::randn({42, 42}), true, false};
  testStaticRuntime(src, both_true, inner_false);
}
3657

3658
TEST(StaticRuntime, ClampNaNToNum) {
  const auto src1 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=2.0).nan_to_num().clone()
  )JIT";

  const auto src2 = R"JIT(
    def forward(self, a, nan: float):
        return torch.clamp(a, min=-1.0, max=2.0).nan_to_num(nan=nan).clone()
  )JIT";

  const auto src3 = R"JIT(
    def forward(self, a):
        return torch.clamp(a, min=1.0, max=-1.0).nan_to_num().clone()
  )JIT";

  // NaN, +inf, -inf and two ordinary values.
  auto special = at::tensor({
      std::numeric_limits<float>::quiet_NaN(),
      std::numeric_limits<float>::infinity(),
      -std::numeric_limits<float>::infinity(),
      0.0f,
      3.0f});
  auto tiled = special.repeat({10, 5});

  // allclose/equal_nan are needed even though nan_to_num replaces every NaN:
  // testStaticRuntime also checks at the end that the inputs (which still
  // contain NaN) were not changed.
  constexpr bool kAllclose = true;
  constexpr bool kEqualNan = true;

  testStaticRuntime(src1, {special}, {}, kAllclose, kEqualNan);
  testStaticRuntime(src1, {special}, {tiled}, kAllclose, kEqualNan);

  testStaticRuntime(src2, {special, 42.0}, {}, kAllclose, kEqualNan);
  testStaticRuntime(
      src2, {special, 2.0}, {tiled, 1.0}, kAllclose, kEqualNan);

  testStaticRuntime(src3, {special}, {}, kAllclose, kEqualNan);
  testStaticRuntime(src3, {special}, {tiled}, kAllclose, kEqualNan);

  // Non-NNC path: double-precision inputs.
  testStaticRuntime(
      src1, {special.to(at::kDouble)}, {}, kAllclose, kEqualNan);
  testStaticRuntime(
      src1,
      {special.to(at::kDouble)},
      {tiled.to(at::kDouble)},
      kAllclose,
      kEqualNan);
}
3698

3699
TEST(StaticRuntime, IfReturningTuple) {
  // The tuple built inside the `if` is indexed with a runtime `idx`, so the
  // returned tensor depends on both `cond` and `idx`.
  const auto src = R"JIT(
    def forward(self, x, y, cond: bool, idx: int):
        if cond:
            tup = (x, y)
        else:
            tup = (x, x)
        return tup[idx]
  )JIT";

  // Exercise both branches of the `if` and both tuple positions.
  for (const bool cond : {true, false}) {
    for (const int64_t idx : {0, 1}) {
      std::vector<IValue> args{at::randn({3}), at::randn({3}), cond, idx};
      testStaticRuntime(src, args);
    }
  }
}
3712

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.