pytorch
90 строк · 2.3 Кб
1#include <gtest/gtest.h>
2
3#include <stdexcept>
4#include "test/cpp/tensorexpr/test_base.h"
5
6#include <torch/csrc/jit/tensorexpr/expr.h>
7#include <torch/csrc/jit/tensorexpr/ir.h>
8#include <torch/csrc/jit/tensorexpr/ir_printer.h>
9#include <torch/csrc/jit/tensorexpr/loopnest.h>
10#include <torch/csrc/jit/tensorexpr/tensor.h>
11#include <torch/csrc/jit/testing/file_check.h>
12
13#include <sstream>
14namespace torch {
15namespace jit {
16
17using namespace torch::jit::tensorexpr;
18
19TEST(IRPrinter, BasicValueTest) {
20ExprHandle a = IntImm::make(2), b = IntImm::make(3);
21ExprHandle c = Add::make(a, b);
22
23std::stringstream ss;
24ss << c;
25ASSERT_EQ(ss.str(), "2 + 3");
26}
27
28TEST(IRPrinter, BasicValueTest02) {
29ExprHandle a(2.0f);
30ExprHandle b(3.0f);
31ExprHandle c(4.0f);
32ExprHandle d(5.0f);
33ExprHandle f = (a + b) - (c + d);
34
35std::stringstream ss;
36ss << f;
37ASSERT_EQ(ss.str(), "(2.f + 3.f) - (4.f + 5.f)");
38}
39
40TEST(IRPrinter, CastTest) {
41VarHandle x("x", kHalf);
42VarHandle y("y", kFloat);
43ExprHandle body = ExprHandle(2.f) +
44(Cast::make(kFloat, x) * ExprHandle(3.f) + ExprHandle(4.f) * y);
45
46std::stringstream ss;
47ss << body;
48ASSERT_EQ(ss.str(), "2.f + (float(x) * 3.f + 4.f * y)");
49}
50
51TEST(IRPrinter, FunctionName) {
52int M = 4;
53int N = 20;
54
55Tensor producer = Compute(
56"producer", {M, N}, [&](const ExprHandle& m, const ExprHandle& n) {
57return m * n;
58});
59
60Tensor chunk_0 = Compute(
61"chunk_0", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
62return producer.load(m, n);
63});
64
65Tensor chunk_1 = Compute(
66"chunk_1", {M, N / 2}, [&](const ExprHandle& m, const ExprHandle& n) {
67return producer.load(m, n + ExprHandle(N / 2));
68});
69
70Tensor consumer = Compute(
71"consumer", {M, N / 2}, [&](const ExprHandle& i, const ExprHandle& j) {
72return i * chunk_1.load(i, j);
73});
74
75LoopNest l({chunk_0, chunk_1, consumer});
76auto body = LoopNest::sanitizeNames(l.root_stmt());
77
78std::stringstream ss;
79ss << *body;
80
81const std::string& verification_pattern =
82R"IR(
83# CHECK: for (int i_2
84# CHECK: for (int j_2
85# CHECK: consumer[i_2, j_2] = i_2 * (chunk_1[i_2, j_2])IR";
86
87torch::jit::testing::FileCheck().run(verification_pattern, ss.str());
88}
89} // namespace jit
90} // namespace torch
91