1
#include <gtest/gtest.h>
3
#include <torch/torch.h>
5
#include <test/cpp/api/support.h>
9
using namespace torch::test;
11
void torch_warn_once_A() {
12
TORCH_WARN_ONCE("warn once");
15
void torch_warn_once_B() {
16
TORCH_WARN_ONCE("warn something else once");
20
TORCH_WARN("warn multiple times");
23
TEST(UtilsTest, WarnOnce) {
25
WarningCapture warnings;
32
ASSERT_EQ(count_substr_occurrences(warnings.str(), "warn once"), 1);
34
count_substr_occurrences(warnings.str(), "warn something else once"),
38
WarningCapture warnings;
45
count_substr_occurrences(warnings.str(), "warn multiple times"), 3);
49
// With a NoGradGuard active for the whole test, the forward pass through the
// Linear layer records no autograd history. Calling backward() on the summed
// output must therefore fail, with the same message the Python API produces.
TEST(NoGradTest, SetsGradModeCorrectly) {
  torch::manual_seed(0);
  torch::NoGradGuard guard;
  torch::nn::Linear model(5, 2);
  auto x = torch::randn({10, 5}, torch::requires_grad());
  auto y = model->forward(x);
  torch::Tensor s = y.sum();

  // Mimicking python API behavior:
  ASSERT_THROWS_WITH(
      s.backward(),
      "element 0 of tensors does not require grad and does not have a grad_fn");
}
63
struct AutogradTest : torch::test::SeedingFixture {
65
x = torch::randn({3, 3}, torch::requires_grad());
66
y = torch::randn({3, 3});
69
torch::Tensor x, y, z;
72
// Backpropagates a gradient of ones through the non-scalar z and expects
// x.grad() to equal y. That expectation is consistent with the fixture
// building z as the elementwise product x * y.
TEST_F(AutogradTest, CanTakeDerivatives) {
  z.backward(torch::ones_like(z));
  ASSERT_TRUE(x.grad().allclose(y));
}
77
// Same expectation as CanTakeDerivatives, but backpropagating from the
// zero-dim tensor z.sum(). A scalar output needs no explicit gradient argument.
TEST_F(AutogradTest, CanTakeDerivativesOfZeroDimTensors) {
  z.sum().backward();
  ASSERT_TRUE(x.grad().allclose(y));
}
82
// Seeds backward() on the zero-dim z.sum() with a custom gradient of 2. The
// resulting x.grad() is expected to be scaled by that factor (y * 2).
TEST_F(AutogradTest, CanPassCustomGradientInputs) {
  z.sum().backward(torch::ones({}) * 2);
  ASSERT_TRUE(x.grad().allclose(y * 2));
}
87
// Calls at::_test_ambiguous_defaults with one, two and three arguments (the
// last call passes a string literal as the third argument). Nothing is
// asserted: the test passes if every call compiles and runs without throwing.
// NOTE(review): presumably guards against overload-resolution ambiguity between
// operator overloads that have defaulted arguments — confirm against the op's
// schema.
TEST(UtilsTest, AmbiguousOperatorDefaults) {
  auto tmp = at::empty({}, at::kCPU);
  at::_test_ambiguous_defaults(tmp);
  at::_test_ambiguous_defaults(tmp, 1);
  at::_test_ambiguous_defaults(tmp, 1, 1);
  at::_test_ambiguous_defaults(tmp, 2, "2");
}
95
// Returns the first element of `arr`.
// Precondition: `arr` holds a value and is non-empty. value() on an empty
// optional presumably throws, and indexing an empty array is out of bounds;
// neither case is checked here.
int64_t get_first_element(c10::OptionalIntArrayRef arr) {
  return arr.value()[0];
}
99
TEST(OptionalArrayRefTest, DanglingPointerFix) {
  // Ensure that the converting constructor of `OptionalArrayRef` does not
  // create a dangling pointer when given a single value
  // ASSERT_EQ (rather than ASSERT_TRUE(a == b)) prints both values on failure.
  ASSERT_EQ(get_first_element(300), 300);
  ASSERT_EQ(get_first_element({400}), 400);
}