pytorch

Форк
0
/
nested_int.cpp 
105 строк · 3.2 Кб
1
#include <gtest/gtest.h>
2

3
#include <ATen/core/NestedIntSymNodeImpl.h>
4
#include <c10/core/SymInt.h>
5
#include <c10/core/SymNodeImpl.h>
6
#include <torch/torch.h>
7

8
#include <test/cpp/api/support.h>
9

10
TEST(NestedIntTest, Comparisons) {
11
  auto a = c10::SymInt(
12
      c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
13
  auto b = c10::SymInt(
14
      c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 1)));
15
  auto c = c10::SymInt(
16
      c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(2, 1)));
17
  auto d = c10::SymInt(3);
18

19
  ASSERT_TRUE(a == a);
20
  ASSERT_TRUE(a == b);
21
  ASSERT_FALSE(a != a);
22
  ASSERT_FALSE(a != b);
23
  ASSERT_FALSE(a == c);
24
  ASSERT_TRUE(a != c);
25

26
  ASSERT_FALSE(a == d);
27
  ASSERT_TRUE(a != d);
28
  ASSERT_FALSE(d == a);
29
  ASSERT_TRUE(d != a);
30

31
  // ge
32
  ASSERT_TRUE(a >= a);
33
  ASSERT_TRUE(a >= b);
34
  ASSERT_TRUE(b >= a);
35
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
36
  EXPECT_THROW((void)(a >= c), c10::Error);
37
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
38
  EXPECT_THROW((void)(c >= a), c10::Error);
39
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
40
  EXPECT_THROW((void)(c >= 3), c10::Error);
41
  ASSERT_TRUE(c >= 2);
42
  ASSERT_TRUE(c >= 1);
43
  ASSERT_FALSE(1 >= c);
44

45
  // lt
46
  ASSERT_FALSE(a < a);
47
  ASSERT_FALSE(a < b);
48
  ASSERT_FALSE(b < a);
49
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
50
  EXPECT_THROW((void)(a < c), c10::Error);
51
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
52
  EXPECT_THROW((void)(c < a), c10::Error);
53
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
54
  EXPECT_THROW((void)(3 < a), c10::Error);
55
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
56
  EXPECT_THROW((void)(2 < a), c10::Error);
57
  ASSERT_TRUE(1 < a);
58

59
  // le
60
  ASSERT_TRUE(a <= a);
61
  ASSERT_TRUE(b <= a);
62
  ASSERT_TRUE(a <= b);
63
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
64
  EXPECT_THROW((void)(a <= c), c10::Error);
65
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
66
  EXPECT_THROW((void)(c <= a), c10::Error);
67
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
68
  EXPECT_THROW((void)(3 <= c), c10::Error);
69
  ASSERT_TRUE(2 <= c);
70
  ASSERT_TRUE(1 <= c);
71
  ASSERT_FALSE(c <= 1);
72

73
  // gt
74
  ASSERT_FALSE(a > a);
75
  ASSERT_FALSE(b > a);
76
  ASSERT_FALSE(a > b);
77
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
78
  EXPECT_THROW((void)(a > c), c10::Error);
79
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
80
  EXPECT_THROW((void)(c > a), c10::Error);
81
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
82
  EXPECT_THROW((void)(a > 3), c10::Error);
83
  // NOLINTNEXTLINE(hicpp-avoid-goto,cppcoreguidelines-avoid-goto)
84
  EXPECT_THROW((void)(a > 2), c10::Error);
85
  ASSERT_TRUE(a > 1);
86
}
87

88
TEST(NestedIntTest, WithFactor) {
89
  auto a = c10::SymInt(
90
      c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 5)));
91
  auto b = c10::SymInt(
92
      c10::SymNode(c10::make_intrusive<c10::NestedIntSymNodeImpl>(1, 10)));
93
  // eq
94
  ASSERT_FALSE(a == b);
95
  ASSERT_FALSE(a >= b);
96
  ASSERT_TRUE(b >= a);
97
  ASSERT_TRUE(a <= b);
98
  ASSERT_FALSE(b <= a);
99
  // ne
100
  ASSERT_TRUE(a != b);
101
  // mul
102
  ASSERT_TRUE(a * 2 == b);
103
  ASSERT_TRUE(a * 3 >= b);
104
  ASSERT_TRUE(a * 2 == 2 * a);
105
}
106

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.