pytorch

Форк
0
/
init.cpp 
131 строка · 4.2 Кб
1
#include <gtest/gtest.h>
2

3
#include <c10/util/irange.h>
4
#include <torch/torch.h>
5

6
#include <test/cpp/api/init_baseline.h>
7
#include <test/cpp/api/support.h>
8

9
#include <functional>
10
#include <vector>
11

12
void check_exact_values(
13
    const std::vector<torch::Tensor>& parameters,
14
    const std::vector<std::vector<torch::Tensor>>& expected_parameters) {
15
  ASSERT_EQ(parameters.size(), expected_parameters.size());
16

17
  for (const auto i : c10::irange(parameters.size())) {
18
    auto layerParameters = parameters[i];
19
    auto expectedLayerParameters = expected_parameters[i];
20

21
    if (static_cast<size_t>(layerParameters.size(0)) !=
22
        expectedLayerParameters.size()) {
23
      std::cout << "layer #" << i
24
                << " layerParameters size: " << layerParameters.size(0)
25
                << " != "
26
                << " expectedLayerParameters size: "
27
                << expectedLayerParameters.size() << std::endl;
28
      ASSERT_TRUE(false);
29
    }
30

31
    for (const auto p : c10::irange(layerParameters.size(0))) {
32
      // Always compare using double dtype, regardless of the original dtype of
33
      // the tensors
34
      auto tensor = layerParameters[p].to(torch::kFloat64);
35
      auto expectedTensor = expectedLayerParameters[p].to(torch::kFloat64);
36

37
      if (!tensor.allclose(expectedTensor, /*rtol=*/1e-3, /*atol=*/5e-4)) {
38
        std::cout << "layer " << i << ": " << tensor << " != " << expectedTensor
39
                  << " (parameter " << p << ")" << std::endl;
40
        ASSERT_TRUE(false);
41
      }
42
    }
43
  }
44
}
45

46
void check_initializer_against_baseline(
47
    std::function<void(torch::Tensor)> initializer,
48
    std::vector<std::vector<torch::Tensor>> expected) {
49
  torch::manual_seed(0);
50

51
  auto layer1 = torch::nn::Linear(7, 15);
52
  initializer(layer1->weight);
53
  layer1->to(torch::kFloat64);
54

55
  auto layer2 = torch::nn::Linear(15, 15);
56
  initializer(layer2->weight);
57
  layer2->to(torch::kFloat64);
58

59
  auto layer3 = torch::nn::Linear(15, 2);
60
  initializer(layer3->weight);
61
  layer3->to(torch::kFloat64);
62

63
  auto parameters = std::vector<torch::Tensor>{
64
      layer1->weight,
65
      layer2->weight,
66
      layer3->weight,
67
  };
68

69
  check_exact_values(parameters, expected);
70
}
71

72
TEST(InitTest, ProducesPyTorchValues_XavierUniform) {
73
  auto expected = expected_parameters::Xavier_Uniform();
74
  auto initializer = [](torch::Tensor tensor) {
75
    torch::nn::init::xavier_uniform_(tensor);
76
  };
77
  check_initializer_against_baseline(initializer, expected);
78
}
79

80
TEST(InitTest, ProducesPyTorchValues_XavierNormal) {
81
  auto expected = expected_parameters::Xavier_Normal();
82
  auto initializer = [](torch::Tensor tensor) {
83
    torch::nn::init::xavier_normal_(tensor);
84
  };
85
  check_initializer_against_baseline(initializer, expected);
86
}
87

88
TEST(InitTest, ProducesPyTorchValues_KaimingNormal) {
89
  auto expected = expected_parameters::Kaiming_Normal();
90
  auto initializer = [](torch::Tensor tensor) {
91
    torch::nn::init::kaiming_normal_(tensor);
92
  };
93
  check_initializer_against_baseline(initializer, expected);
94
}
95

96
TEST(InitTest, ProducesPyTorchValues_KaimingUniform) {
97
  auto expected = expected_parameters::Kaiming_Uniform();
98
  auto initializer = [](torch::Tensor tensor) {
99
    torch::nn::init::kaiming_uniform_(tensor);
100
  };
101
  check_initializer_against_baseline(initializer, expected);
102
}
103

104
TEST(InitTest, CanInitializeTensorThatRequiresGrad) {
105
  auto tensor = torch::empty({3, 4}, torch::requires_grad());
106
  ASSERT_THROWS_WITH(
107
      tensor.fill_(1),
108
      "a leaf Variable that requires grad "
109
      "is being used in an in-place operation");
110
  ASSERT_EQ(torch::nn::init::ones_(tensor).sum().item<int32_t>(), 12);
111
}
112

113
TEST(InitTest, CalculateGainWithTanh) {
114
  double gain = torch::nn::init::calculate_gain(torch::kTanh);
115
  ASSERT_DOUBLE_EQ(gain, 5.0 / 3.0);
116
}
117

118
TEST(InitTest, CalculateGainWithRelu) {
119
  double gain = torch::nn::init::calculate_gain(torch::kReLU);
120
  ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0));
121
}
122

123
TEST(InitTest, CalculateGainWithLeakyRelu) {
124
  double gain = torch::nn::init::calculate_gain(torch::kLeakyReLU);
125
  ASSERT_DOUBLE_EQ(gain, std::sqrt(2.0 / (1 + pow(0.01, 2))));
126
}
127

128
TEST(InitTest, CanInitializeCnnWithOrthogonal) {
129
  torch::nn::Conv2d conv_layer(torch::nn::Conv2dOptions(3, 2, 3).stride(2));
130
  torch::nn::init::orthogonal_(conv_layer->named_parameters()["weight"]);
131
}
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.