pytorch

Форк
0
/
misc.cpp 
104 строки · 2.4 Кб
1
#include <gtest/gtest.h>
2

3
#include <torch/torch.h>
4

5
#include <test/cpp/api/support.h>
6

7
#include <functional>
8

9
using namespace torch::test;
10

11
void torch_warn_once_A() {
12
  TORCH_WARN_ONCE("warn once");
13
}
14

15
void torch_warn_once_B() {
16
  TORCH_WARN_ONCE("warn something else once");
17
}
18

19
void torch_warn() {
20
  TORCH_WARN("warn multiple times");
21
}
22

23
TEST(UtilsTest, WarnOnce) {
24
  {
25
    WarningCapture warnings;
26

27
    torch_warn_once_A();
28
    torch_warn_once_A();
29
    torch_warn_once_B();
30
    torch_warn_once_B();
31

32
    ASSERT_EQ(count_substr_occurrences(warnings.str(), "warn once"), 1);
33
    ASSERT_EQ(
34
        count_substr_occurrences(warnings.str(), "warn something else once"),
35
        1);
36
  }
37
  {
38
    WarningCapture warnings;
39

40
    torch_warn();
41
    torch_warn();
42
    torch_warn();
43

44
    ASSERT_EQ(
45
        count_substr_occurrences(warnings.str(), "warn multiple times"), 3);
46
  }
47
}
48

49
TEST(NoGradTest, SetsGradModeCorrectly) {
50
  torch::manual_seed(0);
51
  torch::NoGradGuard guard;
52
  torch::nn::Linear model(5, 2);
53
  auto x = torch::randn({10, 5}, torch::requires_grad());
54
  auto y = model->forward(x);
55
  torch::Tensor s = y.sum();
56

57
  // Mimicking python API behavior:
58
  ASSERT_THROWS_WITH(
59
      s.backward(),
60
      "element 0 of tensors does not require grad and does not have a grad_fn")
61
}
62

63
struct AutogradTest : torch::test::SeedingFixture {
64
  AutogradTest() {
65
    x = torch::randn({3, 3}, torch::requires_grad());
66
    y = torch::randn({3, 3});
67
    z = x * y;
68
  }
69
  torch::Tensor x, y, z;
70
};
71

72
TEST_F(AutogradTest, CanTakeDerivatives) {
73
  z.backward(torch::ones_like(z));
74
  ASSERT_TRUE(x.grad().allclose(y));
75
}
76

77
TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
78
  z.sum().backward();
79
  ASSERT_TRUE(x.grad().allclose(y));
80
}
81

82
TEST_F(AutogradTest, CanPassCustomGradientInputs) {
83
  z.sum().backward(torch::ones({}) * 2);
84
  ASSERT_TRUE(x.grad().allclose(y * 2));
85
}
86

87
TEST(UtilsTest, AmbiguousOperatorDefaults) {
88
  auto tmp = at::empty({}, at::kCPU);
89
  at::_test_ambiguous_defaults(tmp);
90
  at::_test_ambiguous_defaults(tmp, 1);
91
  at::_test_ambiguous_defaults(tmp, 1, 1);
92
  at::_test_ambiguous_defaults(tmp, 2, "2");
93
}
94

95
int64_t get_first_element(c10::OptionalIntArrayRef arr) {
96
  return arr.value()[0];
97
}
98

99
TEST(OptionalArrayRefTest, DanglingPointerFix) {
100
  // Ensure that the converting constructor of `OptionalArrayRef` does not
101
  // create a dangling pointer when given a single value
102
  ASSERT_TRUE(get_first_element(300) == 300);
103
  ASSERT_TRUE(get_first_element({400}) == 400);
104
}
105

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.