pytorch

Форк
0
/
test_external_calls.cpp 
1058 строк · 36.7 Кб
1
#include <gtest/gtest.h>
2

3
#include <test/cpp/tensorexpr/test_base.h>
4

5
#include <torch/csrc/jit/ir/irparser.h>
6
#include <torch/csrc/jit/passes/subgraph_rewrite.h>
7
#include <torch/csrc/jit/passes/tensorexpr_fuser.h>
8
#include <torch/csrc/jit/runtime/custom_operator.h>
9
#include <torch/csrc/jit/tensorexpr/kernel.h>
10

11
#include <test/cpp/tensorexpr/test_utils.h>
12
#include <torch/csrc/jit/runtime/operator.h>
13
#include <torch/csrc/jit/runtime/symbolic_shape_registry.h>
14
#include <torch/csrc/jit/tensorexpr/eval.h>
15
#include <torch/csrc/jit/tensorexpr/external_functions_registry.h>
16
#include <torch/csrc/jit/tensorexpr/ir.h>
17
#include <torch/csrc/jit/tensorexpr/ir_printer.h>
18
#include <torch/csrc/jit/tensorexpr/ir_simplifier.h>
19
#include <torch/csrc/jit/tensorexpr/llvm_codegen.h>
20
#include <torch/csrc/jit/tensorexpr/loopnest.h>
21
#include <torch/csrc/jit/tensorexpr/tensor.h>
22

23
#include <torch/csrc/jit/testing/file_check.h>
24
#include <torch/jit.h>
25

26
#include <ATen/NativeFunctions.h>
27
#include <ATen/core/dispatch/Dispatcher.h>
28
#include <ATen/native/xnnpack/OpContext.h>
29

30
namespace torch {
31
namespace jit {
32
using namespace torch::jit::tensorexpr;
33

34
// Runs "nnc_aten_conv1d" as an NNC ExternalCall on float buffers (with bias and
// explicit stride/pad/dilation/groups) and compares the output to at::conv1d,
// using the LLVM backend (when built) and SimpleIREvaluator.
TEST(ExternalCall, Conv1d_float) {
  BufHandle Input("Input", {1, 100, 115}, kFloat);
  BufHandle Weight("Weight", {100, 1, 7}, kFloat);
  BufHandle Bias("Bias", {100}, kFloat);
  // Output length: (115 + 2 * pad - dilation * (7 - 1) - 1) / stride + 1 = 115.
  BufHandle ResultBuf("Result", {1, 100, 115}, kFloat);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  // groups equals the 100 input channels, matching the weight's in-channel
  // dimension of 1.
  int64_t groups = 100;

  // Scalar arguments are passed in the order stride, pad, dilation, groups.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 100, 115}, options) * 5.f;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6.f;
  at::Tensor bias = at::ones({100}, options) * 11.f;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  at::Tensor nnc_result;
  // Raw buffers hold the same constant values as the reference tensors above.
  std::vector<float> input_buf(1 * 100 * 115, 5.f);
  std::vector<float> weight_buf(100 * 1 * 7, 6.f);
  std::vector<float> bias_buf(100, 11.f);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(1 * 100 * 115, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
86

87
// Same as Conv1d_float, but with kInt buffers and int32_t host data:
// "nnc_aten_conv1d" is compared against at::conv1d on LLVM (when built) and
// SimpleIREvaluator.
TEST(ExternalCall, Conv1d_int) {
  BufHandle Input("Input", {1, 100, 115}, kInt);
  BufHandle Weight("Weight", {100, 1, 7}, kInt);
  BufHandle Bias("Bias", {100}, kInt);
  // Output length: (115 + 2 * pad - dilation * (7 - 1) - 1) / stride + 1 = 115.
  BufHandle ResultBuf("Result", {1, 100, 115}, kInt);
  int64_t stride = 1;
  int64_t pad = 3;
  int64_t dilation = 1;
  int64_t groups = 100;

  // Scalar arguments are passed in the order stride, pad, dilation, groups.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv1d",
          {Input, Weight, Bias},
          {stride, pad, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kInt)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 100, 115}, options) * 5;
  at::Tensor weight = at::ones({100, 1, 7}, options) * 6;
  at::Tensor bias = at::ones({100}, options) * 11;
  at::Tensor ref =
      at::conv1d(input, weight, bias, {stride}, {pad}, {dilation}, groups);

  at::Tensor nnc_result;
  // Raw buffers hold the same constant values as the reference tensors above.
  std::vector<int32_t> input_buf(1 * 100 * 115, 5);
  std::vector<int32_t> weight_buf(100 * 1 * 7, 6);
  std::vector<int32_t> bias_buf(100, 11);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<int32_t> result_buf(1 * 100 * 115, -1);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 100, 115}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
140

141
// Calls "nnc_aten_conv1d" with only input and weight buffers and no scalar
// arguments, and compares against at::conv1d called with its default
// arguments.
TEST(ExternalCall, Conv1d_nobias_noargs) {
  BufHandle Input("Input", {1, 1, 115}, kFloat);
  BufHandle Weight("Weight", {10, 1, 7}, kFloat);
  // 115 - 7 + 1 = 109 output positions for the reference's default arguments.
  BufHandle ResultBuf("Result", {1, 10, 109}, kFloat);

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_conv1d", {Input, Weight}, {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 1, 115}, options) * 5.f;
  at::Tensor weight = at::ones({10, 1, 7}, options) * 6.f;
  at::Tensor ref = at::conv1d(input, weight);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 1 * 115, 5.f);
  std::vector<float> weight_buf(10 * 1 * 7, 6.f);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(1 * 10 * 109, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});

  llvm_codegen.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});

  ir_eval.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 10, 109}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
181

182
// Runs "nnc_aten_conv2d" as an ExternalCall on float buffers with bias and
// explicit per-dimension stride/pad/dilation plus groups, and compares against
// at::conv2d on LLVM (when built) and SimpleIREvaluator.
TEST(ExternalCall, Conv2d_float) {
  BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
  BufHandle Weight("Weight", {16, 3, 3, 3}, kFloat);
  BufHandle Bias("Bias", {16}, kFloat);
  // Spatial size: (224 + 2 * 1 - 3) / 2 + 1 = 112.
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  // Scalar arguments: stride (h, w), pad (h, w), dilation (h, w), groups.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv2d",
          {Input, Weight, Bias},
          {stride, stride, pad, pad, dilation, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5.f;
  at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6.f;
  at::Tensor bias = at::ones({16}, options) * 11.f;
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 3 * 224 * 224, 5.f);
  std::vector<float> weight_buf(16 * 3 * 3 * 3, 6.f);
  std::vector<float> bias_buf(16, 11.f);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
240

241
// Same as Conv2d_float, but with kInt buffers and int32_t host data.
TEST(ExternalCall, Conv2d_int) {
  BufHandle Input("Input", {1, 3, 224, 224}, kInt);
  BufHandle Weight("Weight", {16, 3, 3, 3}, kInt);
  BufHandle Bias("Bias", {16}, kInt);
  // Spatial size: (224 + 2 * 1 - 3) / 2 + 1 = 112.
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kInt);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  // Scalar arguments: stride (h, w), pad (h, w), dilation (h, w), groups.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_conv2d",
          {Input, Weight, Bias},
          {stride, stride, pad, pad, dilation, dilation, groups}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kInt)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 3, 224, 224}, options) * 5;
  at::Tensor weight = at::ones({16, 3, 3, 3}, options) * 6;
  at::Tensor bias = at::ones({16}, options) * 11;
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  at::Tensor nnc_result;
  std::vector<int32_t> input_buf(1 * 3 * 224 * 224, 5);
  std::vector<int32_t> weight_buf(16 * 3 * 3 * 3, 6);
  std::vector<int32_t> bias_buf(16, 11);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<int32_t> result_buf(1 * 16 * 112 * 112, -1);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Bias, Result});

  llvm_codegen.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Bias, Result});

  ir_eval.call({input_buf, weight_buf, bias_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
301

302
// Calls "nnc_aten_conv2d" with only input and weight buffers and no scalar
// arguments (1x1 kernel), and compares against at::conv2d with its default
// arguments.
TEST(ExternalCall, Conv2d_nobias_noargs) {
  BufHandle Input("Input", {1, 16, 112, 112}, kFloat);
  BufHandle Weight("Weight", {16, 16, 1, 1}, kFloat);
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(ResultBuf, "nnc_aten_conv2d", {Input, Weight}, {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 16, 112, 112}, options) * 5.f;
  at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
  at::Tensor ref = at::conv2d(input, weight);

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 16 * 112 * 112, 5.f);
  std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Weight, Result});

  llvm_codegen.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Weight, Result});

  ir_eval.call({input_buf, weight_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
342

343
// Runs "nnc_aten_addmm" (input, mat1, mat2 with beta/alpha scalars) as an
// ExternalCall and compares against at::addmm on LLVM (when built) and
// SimpleIREvaluator.
TEST(ExternalCall, Addmm_float) {
  BufHandle Input("Input", {100, 300}, kFloat);
  BufHandle Mat1("Mat1", {100, 200}, kFloat);
  BufHandle Mat2("Mat2", {200, 300}, kFloat);
  BufHandle ResultBuf("Result", {100, 300}, kFloat);
  int64_t beta = 2;
  int64_t alpha = 2;

  // Scalar arguments are passed in the order beta, alpha.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf, "nnc_aten_addmm", {Input, Mat1, Mat2}, {beta, alpha}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({100, 300}, options) * 5.f;
  at::Tensor mat1 = at::ones({100, 200}, options) * 6.f;
  at::Tensor mat2 = at::ones({200, 300}, options) * 11.f;
  at::Tensor ref = at::addmm(input, mat1, mat2, beta, alpha);

  at::Tensor nnc_result;
  std::vector<float> input_buf(100 * 300, 5.f);
  std::vector<float> mat1_buf(100 * 200, 6.f);
  std::vector<float> mat2_buf(200 * 300, 11.f);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(100 * 300, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Mat1, Mat2, Result});

  llvm_codegen.call({input_buf, mat1_buf, mat2_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Mat1, Mat2, Result});

  ir_eval.call({input_buf, mat1_buf, mat2_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
389

390
// Runs "nnc_aten_embedding" on a float weight table and int64 (kLong) indices,
// and compares against at::embedding on LLVM (when built) and
// SimpleIREvaluator.
TEST(ExternalCall, Embedding) {
  BufHandle Weight("Weight", {256, 100}, kFloat);
  BufHandle Indices("Indices", {1, 115}, kLong);
  BufHandle ResultBuf("Result", {1, 115, 100}, kFloat);
  int64_t padding_idx = -1;
  bool scale_grad_by_freq = false;
  bool sparse = false;

  // The boolean flags are passed as int64 scalar arguments.
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_aten_embedding",
          {Weight, Indices},
          {padding_idx, (int64_t)scale_grad_by_freq, (int64_t)sparse}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  // dtype is chosen per tensor below because weight and indices differ.
  auto options = at::TensorOptions()
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  at::Tensor weight = at::ones({256, 100}, options.dtype(at::kFloat)) * 5.f;
  at::Tensor indices = at::ones({1, 115}, options.dtype(at::kLong)) * 6;
  at::Tensor ref =
      at::embedding(weight, indices, padding_idx, scale_grad_by_freq, sparse);

  at::Tensor nnc_result;
  std::vector<float> weight_buf(256 * 100, 5.f);
  std::vector<int64_t> indices_buf(1 * 115, 6);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(1 * 115 * 100, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Weight, Indices, Result});

  llvm_codegen.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Weight, Indices, Result});

  ir_eval.call({weight_buf, indices_buf, result_buf});
  nnc_result = at::from_blob(
      result_buf.data(), {1, 115, 100}, options.dtype(at::kFloat));
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
440

441
// Runs "nnc_aten_max_red" (max along `dim`, without keepdim) and compares the
// values against the first element of at::max(input, dim, keep_dim).
TEST(ExternalCall, MaxReduction) {
  BufHandle Input("Input", {1, 115, 152}, kFloat);
  // Dimension 1 (size 115) is reduced away because keep_dim is false.
  BufHandle ResultBuf("Result", {1, 152}, kFloat);
  int64_t dim = 1;
  bool keep_dim = false;

  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf, "nnc_aten_max_red", {Input}, {dim, (int64_t)keep_dim}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  at::Tensor input = at::ones({1, 115, 152}, options) * 5.f;
  at::Tensor ref = std::get<0>(at::max(input, dim, keep_dim));

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 115 * 152, 5.f);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(1 * 152, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, Result});

  llvm_codegen.call({input_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, Result});

  ir_eval.call({input_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 152}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
482

483
#ifdef USE_XNNPACK
484

485
// Prepacks linear weights and bias with the XNNPACK
// "prepacked::linear_clamp_prepack" operator, then runs
// "nnc_prepacked_linear_clamp_run" as an ExternalCall and compares against
// at::linear.
TEST(ExternalCall, Prepacked_Linear_float) {
  using namespace at::native::xnnpack;

  BufHandle Input("Input", {100, 200}, kFloat);
  BufHandle ResultBuf("Result", {100, 300}, kFloat);

  // Calculate reference result using at::linear.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input =
      at::linspace(-10.0, 10.0, 100 * 200, options).resize_({100, 200});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 300 * 200, options).resize_({300, 200});
  at::Tensor bias = at::linspace(-10.0, 10.0, 300, options);
  at::Tensor ref = at::linear(input, weight, bias);

  // Create prepacked xnnpack context object. Both clamp bounds are left empty.
  auto linear_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::linear_clamp_prepack", "")
          .typed<c10::intrusive_ptr<LinearOpContext>(
              at::Tensor,
              std::optional<at::Tensor>,
              const std::optional<at::Scalar>&,
              const std::optional<at::Scalar>&)>();
  auto prepacked = linear_clamp_prepack_op.call(
      weight, bias, std::optional<at::Scalar>(), std::optional<at::Scalar>());

  // Placeholder buffer: at call time its argument slot receives
  // prepacked.get() instead of float data.
  BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_linear_clamp_run",
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 100 * 200);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(100 * 300, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {100, 300}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
547

548
// Prepacks conv2d weights and bias with the XNNPACK
// "prepacked::conv2d_clamp_prepack" operator, then runs
// "nnc_prepacked_conv2d_clamp_run" as an ExternalCall and compares against
// at::conv2d (with a 1e-3 tolerance).
TEST(ExternalCall, Prepacked_Conv2d_float) {
  using namespace at::native::xnnpack;

  BufHandle Input("Input", {1, 3, 224, 224}, kFloat);
  // Spatial size: (224 + 2 * 1 - 3) / 2 + 1 = 112.
  BufHandle ResultBuf("Result", {1, 16, 112, 112}, kFloat);
  int64_t stride = 2;
  int64_t pad = 1;
  int64_t dilation = 1;
  int64_t groups = 1;

  // Calculate reference result using at::conv2d.
  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::linspace(-10.0, 10.0, 1 * 3 * 224 * 224, options)
                         .resize_({1, 3, 224, 224});
  at::Tensor weight =
      at::linspace(-10.0, 10.0, 16 * 3 * 3 * 3, options).resize_({16, 3, 3, 3});
  at::Tensor bias = at::linspace(-10.0, 10.0, 16, options);
  at::Tensor ref = at::conv2d(
      input,
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups);

  // Create prepacked xnnpack context object. Both clamp bounds are left empty.
  auto conv2d_clamp_prepack_op =
      c10::Dispatcher::singleton()
          .findSchemaOrThrow("prepacked::conv2d_clamp_prepack", "")
          .typed<c10::intrusive_ptr<Conv2dOpContext>(
              at::Tensor,
              std::optional<at::Tensor>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              std::vector<int64_t>,
              int64_t,
              const std::optional<at::Scalar>&,
              const std::optional<at::Scalar>&)>();
  auto prepacked = conv2d_clamp_prepack_op.call(
      weight,
      bias,
      {stride, stride},
      {pad, pad},
      {dilation, dilation},
      groups,
      std::optional<at::Scalar>(),
      std::optional<at::Scalar>());

  // Placeholder buffer: at call time its argument slot receives
  // prepacked.get() instead of float data.
  BufHandle DummyPrepacked("DummyPrepacked", {1}, kFloat);
  Tensor Result = Tensor(
      ResultBuf.node(),
      ExternalCall::make(
          ResultBuf,
          "nnc_prepacked_conv2d_clamp_run",
          {Input, DummyPrepacked},
          {}));
  LoopNest l({Result});
  l.prepareForCodegen();
  l.simplify();

  at::Tensor nnc_result;
  std::vector<float> input_buf(
      input.data_ptr<float>(), input.data_ptr<float>() + 1 * 3 * 224 * 224);
  // -1 acts as a sentinel for output elements the kernel never writes.
  std::vector<float> result_buf(1 * 16 * 112 * 112, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Input, DummyPrepacked, Result});

  llvm_codegen.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));

  // Restore the sentinel so the interpreter check below cannot pass on the
  // values the LLVM run already wrote.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Input, DummyPrepacked, Result});

  ir_eval.call({input_buf, prepacked.get(), result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 112, 112}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref, 1e-03, 1e-03));
}
632

633
#endif // USE_XNNPACK
634

635
// Table-driven check of two-input external calls (matmul, mv, mm) on float
// buffers. Each case is compared against the matching ATen function on LLVM
// (when built) and SimpleIREvaluator.
TEST(ExternalCall, BinaryFloat) {
  using TensorFunc = std::function<at::Tensor(at::Tensor, at::Tensor)>;
  // (A shape, B shape, result shape, ATen reference, external call name)
  using Test = std::tuple<
      std::vector<int64_t>,
      std::vector<int64_t>,
      std::vector<int64_t>,
      TensorFunc,
      std::string>;
  std::vector<Test> tests = {};
  tests.push_back(
      Test{{100, 200}, {200, 300}, {100, 300}, at::matmul, "nnc_aten_matmul"});
  tests.push_back(Test{{100, 300}, {300}, {100}, at::mv, "nnc_aten_mv"});
  tests.push_back(
      Test{{100, 200}, {200, 300}, {100, 300}, at::mm, "nnc_aten_mm"});

  // Converts a shape into NNC dimension expressions (through int).
  auto toExprHandleVec = [](const std::vector<int64_t>& v) {
    auto intV = std::vector<int>(v.begin(), v.end());
    return std::vector<ExprHandle>(intV.begin(), intV.end());
  };
  // Element count of a shape. The init value is int64_t so std::accumulate
  // keeps the running product in int64_t instead of narrowing it to int.
  auto prod = [](const std::vector<int64_t>& v) {
    return std::accumulate(v.begin(), v.end(), int64_t{1}, std::multiplies<>());
  };

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  for (const auto& curTest : tests) {
    const auto& [aShape, bShape, resShape, torchFunc, externCallName] = curTest;
    BufHandle A("A", toExprHandleVec(aShape), kFloat);
    BufHandle B("B", toExprHandleVec(bShape), kFloat);
    BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);

    Tensor Result = Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, externCallName, {A, B}, {}));
    LoopNest l({Result});
    l.prepareForCodegen();
    l.simplify();

    at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
    at::Tensor b = at::ones(c10::IntArrayRef(bShape), options) * 6.f;
    at::Tensor ref = torchFunc(a, b);

    at::Tensor nnc_result;
    std::vector<float> a_buf(prod(aShape), 5.f);
    std::vector<float> b_buf(prod(bShape), 6.f);
    // -1 acts as a sentinel for output elements the kernel never writes.
    std::vector<float> result_buf(prod(resShape), -1.f);

#ifdef TORCH_ENABLE_LLVM
    LLVMCodeGen llvm_codegen(l.root_stmt(), {A, B, Result});

    llvm_codegen.call({a_buf, b_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));

    // Restore the sentinel so the interpreter check below cannot pass on the
    // values the LLVM run already wrote.
    result_buf.assign(result_buf.size(), -1.f);
#endif

    SimpleIREvaluator ir_eval(l.root_stmt(), {A, B, Result});
    ir_eval.call({a_buf, b_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
  }
}
701

702
// Table-driven check of single-input external calls that also take scalar
// arguments (adaptive_avg_pool2d, mean). Each case is compared against an ATen
// reference lambda on LLVM (when built) and SimpleIREvaluator.
TEST(ExternalCall, UnaryFloat) {
  using TensorFunc = std::function<at::Tensor(at::Tensor)>;
  // Converts integers into NNC expressions (through int). Used for both
  // buffer shapes and external-call scalar arguments.
  auto toExprHandleVec = [](const std::vector<int64_t>& v) {
    auto intV = std::vector<int>(v.begin(), v.end());
    return std::vector<ExprHandle>(intV.begin(), intV.end());
  };
  // (input shape, result shape, ATen reference, external call name,
  //  external call scalar args)
  using Test = std::tuple<
      std::vector<int64_t>,
      std::vector<int64_t>,
      TensorFunc,
      std::string,
      std::vector<ExprHandle>>;
  std::vector<Test> tests = {};
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {1, 64, 8, 9},
                       {1, 64, 5, 7},
                       [](at::Tensor x) {
                         return at::adaptive_avg_pool2d(x, {5, 7});
                       },
                       "nnc_aten_adaptive_avg_pool2d",
                       toExprHandleVec({5, 7})});
  tests.push_back(Test{// NOLINTNEXTLINE(cppcoreguidelines-avoid-magic-numbers)
                       {100, 200},
                       {100},
                       [](at::Tensor x) { return at::mean(x, {1}); },
                       "nnc_aten_mean",
                       toExprHandleVec({1, /*keepdim=*/0})});

  // Element count of a shape. The init value is int64_t so std::accumulate
  // keeps the running product in int64_t instead of narrowing it to int.
  auto prod = [](const std::vector<int64_t>& v) {
    return std::accumulate(v.begin(), v.end(), int64_t{1}, std::multiplies<>());
  };

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);

  for (const auto& curTest : tests) {
    const auto& [aShape, resShape, torchFunc, externCallName, externCallArgs] =
        curTest;
    BufHandle A("A", toExprHandleVec(aShape), kFloat);
    BufHandle ResultBuf("Result", toExprHandleVec(resShape), kFloat);

    Tensor Result = Tensor(
        ResultBuf.node(),
        ExternalCall::make(ResultBuf, externCallName, {A}, externCallArgs));
    LoopNest l({Result});
    l.prepareForCodegen();
    l.simplify();

    at::Tensor a = at::ones(c10::IntArrayRef(aShape), options) * 5.f;
    at::Tensor ref = torchFunc(a);

    at::Tensor nnc_result;
    std::vector<float> a_buf(prod(aShape), 5.f);
    // -1 acts as a sentinel for output elements the kernel never writes.
    std::vector<float> result_buf(prod(resShape), -1.f);

#ifdef TORCH_ENABLE_LLVM
    LLVMCodeGen llvm_codegen(l.root_stmt(), {A, Result});

    llvm_codegen.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));

    // Restore the sentinel so the interpreter check below cannot pass on the
    // values the LLVM run already wrote.
    result_buf.assign(result_buf.size(), -1.f);
#endif

    SimpleIREvaluator ir_eval(l.root_stmt(), {A, Result});
    ir_eval.call({a_buf, result_buf});
    nnc_result =
        at::from_blob(result_buf.data(), c10::IntArrayRef(resShape), options);
    ASSERT_TRUE(at::allclose(nnc_result, ref));
  }
}
775

776
// This test verifies that Tensors using external calls can be used by and can
// use Tensors built with Compute API: Compute -> conv2d ExternalCall ->
// matmul ExternalCall -> Compute that adds both external results.
TEST(ExternalCall, ComputeInterop) {
  BufHandle ConvResultBuf("ConvResult", {1, 16, 32, 32}, kFloat);
  BufHandle MatmulResultBuf("MatmulResult", {1, 16, 32, 32}, kFloat);

  Tensor Input = Compute(
      "Input",
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(5.0f); });
  Tensor Weight = Compute(
      "Weight",
      {16, 16, 1, 1},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) { return FloatImm::make(6.0f); });

  Tensor ConvResult = Tensor(
      ConvResultBuf.node(),
      ExternalCall::make(
          ConvResultBuf,
          "nnc_aten_conv2d",
          {BufHandle(Input.buf()), BufHandle(Weight.buf())},
          {}));
  // The conv output is fed as both matmul operands.
  Tensor MatmulResult = Tensor(
      MatmulResultBuf.node(),
      ExternalCall::make(
          MatmulResultBuf,
          "nnc_aten_matmul",
          {BufHandle(ConvResult.buf()), BufHandle(ConvResult.buf())},
          {}));
  Tensor Result = Compute(
      "Result",
      {1, 16, 32, 32},
      [&](const VarHandle& n,
          const VarHandle& c,
          const VarHandle& h,
          const VarHandle& w) {
        return ConvResult.load(n, c, h, w) + MatmulResult.load(n, c, h, w);
      });

  LoopNest l({Input, Weight, ConvResult, MatmulResult, Result});

  // Inlining should not inline anything here since all Bufs are either defined
  // or used in ExternalCalls - we run it just for testing
  l.inlineIntermediateBufs(true);

  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor input = at::ones({1, 16, 32, 32}, options) * 5.f;
  at::Tensor weight = at::ones({16, 16, 1, 1}, options) * 6.f;
  at::Tensor t = at::conv2d(input, weight);
  at::Tensor t2 = at::matmul(t, t);
  at::Tensor ref = t + t2;

  at::Tensor nnc_result;
  std::vector<float> input_buf(1 * 16 * 32 * 32, 5.f);
  std::vector<float> weight_buf(16 * 16 * 1 * 1, 6.f);
  // -1 acts as a sentinel in the intermediate and final output buffers.
  std::vector<float> conv_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> matmul_result_buf(1 * 16 * 32 * 32, -1.f);
  std::vector<float> result_buf(1 * 16 * 32 * 32, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  llvm_codegen.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Restore the sentinels (including intermediates) so the interpreter check
  // below cannot pass on values the LLVM run already wrote.
  conv_result_buf.assign(conv_result_buf.size(), -1.f);
  matmul_result_buf.assign(matmul_result_buf.size(), -1.f);
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(
      l.root_stmt(), {Input, Weight, ConvResult, MatmulResult, Result});

  ir_eval.call(
      {input_buf, weight_buf, conv_result_buf, matmul_result_buf, result_buf});
  nnc_result = at::from_blob(result_buf.data(), {1, 16, 32, 32}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
867

868
TEST(ExternalCall, Inlining) {
  // This test verifies that Tensors using external calls can be used by and
  // can use Tensors built with Compute API: A and B (Compute) feed an
  // nnc_aten_matmul ExternalCall, whose result is consumed by Result (Compute).

  BufHandle MatmulResultBuf("MatmulResult", {8, 8}, kFloat);

  Tensor A = Compute("A", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
    return FloatImm::make(5.0f);
  });
  Tensor B = Compute("B", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
    return FloatImm::make(4.0f);
  });
  Tensor MatmulResult = Tensor(
      MatmulResultBuf.node(),
      ExternalCall::make(
          MatmulResultBuf,
          "nnc_aten_matmul",
          {BufHandle(A.buf()), BufHandle(B.buf())},
          {}));
  Tensor Result =
      Compute("Result", {8, 8}, [&](const VarHandle& i, const VarHandle& j) {
        return MatmulResult.load(i, j) + FloatImm::make(3.0f);
      });

  // The root statement is assembled by hand and only Result.buf() is handed
  // to LoopNest, so A, B and MatmulResult are intermediate buffers.
  StmtPtr root_stmt = alloc<torch::jit::tensorexpr::Block>(std::vector<StmtPtr>(
      {A.stmt(), B.stmt(), MatmulResult.stmt(), Result.stmt()}));
  LoopNest l(root_stmt, {Result.buf()});

  // Inlining should not inline anything here since all Bufs are either
  // defined or used in ExternalCalls
  l.inlineIntermediateBufs(false);

  l.prepareForCodegen();
  l.simplify();

  auto options = at::TensorOptions()
                     .dtype(at::kFloat)
                     .layout(at::kStrided)
                     .device(at::kCPU)
                     .requires_grad(false);
  at::Tensor a = at::ones({8, 8}, options) * 5.f;
  at::Tensor b = at::ones({8, 8}, options) * 4.f;
  at::Tensor t = at::matmul(a, b);
  at::Tensor ref = t + 3.f;

  at::Tensor nnc_result;
  // Pre-fill the output with a sentinel that never matches `ref`, so a backend
  // that fails to write its output makes the assertion fail.
  std::vector<float> result_buf(8 * 8, -1.f);

#ifdef TORCH_ENABLE_LLVM
  LLVMCodeGen llvm_codegen(l.root_stmt(), {Result});

  llvm_codegen.call({result_buf});
  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));

  // Reset the output so the IR-evaluator check below cannot pass on values
  // left behind by the LLVM run.
  result_buf.assign(result_buf.size(), -1.f);
#endif

  SimpleIREvaluator ir_eval(l.root_stmt(), {Result});

  ir_eval.call({result_buf});
  nnc_result = at::from_blob(result_buf.data(), {8, 8}, options);
  ASSERT_TRUE(at::allclose(nnc_result, ref));
}
930

931
TEST(ExternalCall, JitCustomFusionOp) {
  // Registers a custom JIT op `nnc_custom::add_mul` computing (a + b) * c,
  // together with an NNC lowering and an external-function stub, then checks
  // that FuseTensorExprs places the op inside a prim::TensorExprGroup both for
  // the static-shape case and after a shape-compute graph is registered for it.
  const char* custom_op_schema_literal =
      "nnc_custom::add_mul(Tensor a, Tensor b, Tensor c) -> Tensor";
  const char* external_func_name = "nnc_add_mul";

  // NNC lowering: emits a single ExternalCall to `external_func_name` that
  // reads the three input buffers (a, b, c) and writes a new result buffer.
  auto add_mul_lowering_func =
      [external_func_name](
          const std::vector<torch::jit::tensorexpr::ArgValue>& inputs,
          const std::vector<torch::jit::tensorexpr::ExprHandle>& output_shape,
          const std::vector<torch::jit::tensorexpr::ExprHandle>& output_strides,
          const std::optional<torch::jit::tensorexpr::ScalarType>& output_type,
          at::Device device) {
        // Dereferencing an empty optional is UB; fail loudly instead.
        TORCH_CHECK(
            output_type.has_value(),
            "nnc_custom::add_mul lowering requires a known output dtype");
        auto output_dtype = Dtype(*output_type);
        torch::jit::tensorexpr::BufHandle result_buf(
            "nnc_add_mul_res_buf", output_shape, output_dtype);
        // Operands follow the schema order: a, b, c.
        const torch::jit::tensorexpr::BufHandle& a =
            std::get<torch::jit::tensorexpr::BufHandle>(inputs[0]);
        const torch::jit::tensorexpr::BufHandle& b =
            std::get<torch::jit::tensorexpr::BufHandle>(inputs[1]);
        const torch::jit::tensorexpr::BufHandle& c =
            std::get<torch::jit::tensorexpr::BufHandle>(inputs[2]);
        torch::jit::tensorexpr::StmtPtr s =
            torch::jit::tensorexpr::ExternalCall::make(
                result_buf, external_func_name, {a, b, c}, {});
        return Tensor(result_buf.node(), s);
      };

  // External-function stub registered under `external_func_name`. It is never
  // executed by this test (only fusion is checked), so its body is empty.
  auto add_mul_external_func = [](int64_t bufs_num,
                                  void** buf_data,
                                  int64_t* buf_ranks,
                                  int64_t* buf_dims,
                                  int64_t* buf_strides,
                                  int8_t* buf_dtypes,
                                  int64_t args_num,
                                  int64_t* extra_args) {};

  // Interpreter implementation of the custom op: pops a, b, c and pushes
  // (a + b) * c. `reg` deregisters the operator when it goes out of scope.
  torch::jit::RegisterOperators reg({Operator(
      custom_op_schema_literal,
      [](const Node* node) -> Operation {
        return [](Stack& _stack) {
          auto a = std::move(peek(_stack, 0, 3)).toTensor();
          auto b = std::move(peek(_stack, 1, 3)).toTensor();
          auto c = std::move(peek(_stack, 2, 3)).toTensor();
          drop(_stack, 3);
          auto result = (a + b) * c;
          pack(_stack, std::move(result));
          return 0;
        };
      },
      c10::AliasAnalysisKind::FROM_SCHEMA)});

  // Make the op visible to the TensorExpr fuser and give it an NNC lowering
  // plus the external function that lowering calls.
  auto& custom_operator_set = torch::jit::tensorexpr::getCustomOperatorSet();
  custom_operator_set.insert({custom_op_schema_literal});

  auto& te_lowering_registry = torch::jit::tensorexpr::getNNCLoweringRegistry();
  te_lowering_registry.insert(
      parseSchema(custom_op_schema_literal), add_mul_lowering_func);

  auto& te_nnc_func_registry = torch::jit::tensorexpr::getNNCFunctionRegistry();
  te_nnc_func_registry[external_func_name] = add_mul_external_func;

  std::string graph_string = R"IR(
    graph(%a : Float(10, 20, strides=[20, 1], device=cpu),
          %b : Float(10, 20, strides=[20, 1], device=cpu),
          %c : Float(10, 20, strides=[20, 1], device=cpu)):
      %res : Float(10, 20, strides=[20, 1], device=cpu) = nnc_custom::add_mul(%a, %b, %c)
      return (%res))IR";

  auto graph = std::make_shared<Graph>();
  torch::jit::parseIR(graph_string, graph.get());

  // Shape function for the dynamic-shape case: the broadcast shape of a, b, c.
  // A broadcasts with b first, then the result broadcasts with c. A size-1
  // dimension adopts the other operand's size; otherwise the sizes must match.
  std::string shape_compute_python_string = R"PY(
  def computeOutput(a: List[int], b: List[int], c: List[int]):
    expandedSizes: List[int] = []
    dimsA = len(a)
    dimsB = len(b)
    dimsC = len(c)
    ndim = max(dimsA, dimsB, dimsC)
    for i in range(ndim):
        offset = ndim - 1 - i
        dimA = dimsA - 1 - offset
        dimB = dimsB - 1 - offset
        dimC = dimsC - 1 - offset
        sizeA = a[dimA] if (dimA >= 0) else 1
        sizeB = b[dimB] if (dimB >= 0) else 1
        sizeC = c[dimC] if (dimC >= 0) else 1

        if sizeA != sizeB and sizeA != 1 and sizeB != 1:
            # TODO: only assertion error is bound in C++ compilation right now
            raise AssertionError(
                "The size of tensor a ({}) must match the size of tensor b "
                "({}) at non-singleton dimension {}".format(sizeA, sizeB, i)
            )
        sizeAB = sizeB if sizeA == 1 else sizeA

        if sizeAB != sizeC and sizeAB != 1 and sizeC != 1:
            raise AssertionError(
                "The size of tensor c ({}) must match the broadcast size of "
                "tensors a and b ({}) at non-singleton dimension {}".format(sizeC, sizeAB, i)
            )

        expandedSizes.append(sizeC if sizeAB == 1 else sizeAB)

    return expandedSizes
  )PY";
  auto cu_ptr = torch::jit::compile(shape_compute_python_string);
  // Checked downcast: yields nullptr if the compiled function is not a
  // GraphFunction, which makes the ASSERT below meaningful.
  torch::jit::GraphFunction* gf = torch::jit::tryToGraphFunction(
      cu_ptr->get_function("computeOutput"));
  ASSERT_TRUE(gf);

#ifdef TORCH_ENABLE_LLVM
  auto static_graph_case = graph->copy();
  FuseTensorExprs(static_graph_case, 1);
  torch::jit::testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("nnc_custom::add_mul")
      ->run(*static_graph_case);

  auto dynamic_graph_case = graph->copy();
  auto custom_op = torch::jit::getOperatorForLiteral(custom_op_schema_literal);
  ASSERT_TRUE(custom_op);
  torch::jit::RegisterShapeComputeGraphForSchema(
      custom_op->schema(), gf->graph());
  FuseTensorExprs(dynamic_graph_case, 1, false, true);
  torch::jit::testing::FileCheck()
      .check("prim::TensorExprGroup_")
      ->check("nnc_custom::add_mul")
      ->run(*dynamic_graph_case);
#else
  torch::jit::testing::FileCheck().check("nnc_custom::add_mul")->run(*graph);
#endif
}
1056

1057
} // namespace jit
1058
} // namespace torch
1059

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.