pytorch

Форк
0
131 строка · 3.5 Кб
1
#include <torch/csrc/jit/mobile/train/optim/sgd.h>
2

3
#include <torch/types.h>
4
#include <torch/utils.h>
5

6
#include <ATen/ATen.h>
7

8
#include <functional>
9

10
namespace torch {
11
namespace jit {
12
namespace mobile {
13

14
bool SGDParamGroup::has_options() const {
15
  return options_ != nullptr;
16
}
17

18
SGDOptions& SGDParamGroup::options() {
19
  TORCH_CHECK(has_options());
20
  return *options_.get();
21
}
22

23
const SGDOptions& SGDParamGroup::options() const {
24
  TORCH_CHECK(has_options());
25
  return *options_.get();
26
}
27

28
void SGDParamGroup::set_options(std::unique_ptr<SGDOptions> options) {
29
  options_ = std::move(options);
30
}
31

32
std::vector<Tensor>& SGDParamGroup::params() {
33
  return params_;
34
}
35

36
const std::vector<Tensor>& SGDParamGroup::params() const {
37
  return params_;
38
}
39

40
SGDOptions::SGDOptions(double lr) : lr_(lr) {}
41

42
bool operator==(const SGDOptions& lhs, const SGDOptions& rhs) {
43
  return (lhs.lr() == rhs.lr()) && (lhs.momentum() == rhs.momentum()) &&
44
      (lhs.dampening() == rhs.dampening()) &&
45
      (lhs.weight_decay() == rhs.weight_decay()) &&
46
      (lhs.nesterov() == rhs.nesterov());
47
}
48

49
bool operator==(const SGDParamState& lhs, const SGDParamState& rhs) {
50
  return torch::equal(lhs.momentum_buffer(), rhs.momentum_buffer());
51
}
52

53
void SGD::add_param_group(const SGDParamGroup& param_group) {
54
  for (const auto& param : param_group.params()) {
55
    TORCH_CHECK(param.is_leaf(), "can't optimize a non-leaf Tensor");
56
  }
57
  TORCH_INTERNAL_ASSERT(defaults_ != nullptr);
58
  SGDParamGroup param_group_(param_group.params());
59
  if (!param_group.has_options()) {
60
    param_group_.set_options(defaults_->clone());
61
  } else {
62
    param_group_.set_options(param_group.options().clone());
63
  }
64
  for (const auto& p : param_group_.params()) {
65
    TORCH_CHECK(
66
        state_.count(p.unsafeGetTensorImpl()) == 0,
67
        "some parameters appear in more than one parameter group");
68
  }
69
  param_groups_.emplace_back(std::move(param_group_));
70
}
71

72
void SGD::zero_grad() {
73
  for (auto& group : param_groups_) {
74
    for (auto& p : group.params()) {
75
      if (p.grad().defined()) {
76
        p.grad().detach_();
77
        p.grad().zero_();
78
      }
79
    }
80
  }
81
}
82

83
Tensor SGD::step(const LossClosure& closure) {
84
  NoGradGuard no_grad;
85
  Tensor loss = {};
86
  if (closure != nullptr) {
87
    at::AutoGradMode enable_grad(true);
88
    loss = closure();
89
  }
90
  for (auto& group : param_groups_) {
91
    auto& options = static_cast<SGDOptions&>(group.options());
92
    auto weight_decay = options.weight_decay();
93
    auto momentum = options.momentum();
94
    auto dampening = options.dampening();
95
    auto nesterov = options.nesterov();
96

97
    for (auto& p : group.params()) {
98
      if (!p.grad().defined()) {
99
        continue;
100
      }
101
      auto d_p = p.grad().data();
102
      if (weight_decay != 0) {
103
        d_p = d_p.add(p.data(), weight_decay);
104
      }
105
      if (momentum != 0) {
106
        Tensor buf;
107
        auto param_state = state_.find(p.unsafeGetTensorImpl());
108
        if (param_state == state_.end()) {
109
          buf = torch::clone(d_p).detach();
110
          auto state = std::make_unique<SGDParamState>();
111
          state->momentum_buffer(buf);
112
          state_[p.unsafeGetTensorImpl()] = std::move(state);
113
        } else {
114
          buf = static_cast<SGDParamState&>(*param_state->second)
115
                    .momentum_buffer();
116
          buf.mul_(momentum).add_(d_p, 1 - dampening);
117
        }
118
        if (nesterov) {
119
          d_p = d_p.add(buf, momentum);
120
        } else {
121
          d_p = buf;
122
        }
123
      }
124
      p.data().add_(d_p, -1 * options.lr());
125
    }
126
  }
127
  return loss;
128
}
129
} // namespace mobile
130
} // namespace jit
131
} // namespace torch
132

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.