google-research

Форк
0
186 строк · 6.0 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "algorithm.h"
16

17
#include <cstdlib>
#include <sstream>
18
#include <string>
19
#include <vector>
20

21
#include "definitions.h"
22
#include "instruction.h"
23
#include "random_generator.h"
24
#include "absl/flags/flag.h"
25

26
namespace automl_zero {
27

28
using ::std::istringstream;  // NOLINT
29
using ::std::make_shared;  // NOLINT
30
using ::std::ostream;  // NOLINT
31
using ::std::ostringstream;  // NOLINT
32
using ::std::shared_ptr;  // NOLINT
33
using ::std::string;  // NOLINT
34
using ::std::stringstream;  // NOLINT
35
using ::std::vector;  // NOLINT
36

37
// Builds an Algorithm from its serialized form. All three component functions
// (setup, predict, learn) are populated by delegating to FromProto().
Algorithm::Algorithm(const SerializedAlgorithm& checkpoint_algorithm) {
  FromProto(checkpoint_algorithm);
}
40

41
inline void ShallowCopyComponentFunction(
42
    const vector<shared_ptr<const Instruction>>& src,
43
    vector<shared_ptr<const Instruction>>* dest) {
44
  dest->reserve(src.size());
45
  dest->clear();
46
  for (const shared_ptr<const Instruction>& src_instr : src) {
47
    dest->emplace_back(src_instr);
48
  }
49
}
50

51
// Copy constructor. Performs a shallow copy: the new Algorithm holds copies of
// `other`'s shared_ptrs, so both refer to the same const Instruction objects.
Algorithm::Algorithm(const Algorithm& other)
    : setup_(other.setup_),
      predict_(other.predict_),
      learn_(other.learn_) {}
56

57
// Copy assignment. Shallow-copies the setup, predict and learn component
// functions from `other` (Instruction objects are shared, not cloned).
// Assigning an Algorithm to itself leaves it unchanged.
Algorithm& Algorithm::operator=(const Algorithm& other) {
  if (this == &other) {
    return *this;
  }
  ShallowCopyComponentFunction(other.setup_, &setup_);
  ShallowCopyComponentFunction(other.predict_, &predict_);
  ShallowCopyComponentFunction(other.learn_, &learn_);
  return *this;
}
65

66
// Move constructor. Takes over `other`'s component-function vectors; `other`
// is left in a valid but moved-from state.
Algorithm::Algorithm(Algorithm&& other)
    : setup_(std::move(other.setup_)),
      predict_(std::move(other.predict_)),
      learn_(std::move(other.learn_)) {}
71

72
// Move assignment. Replaces this Algorithm's component functions with
// `other`'s by moving the vectors. Self-move-assignment is a no-op.
Algorithm& Algorithm::operator=(Algorithm&& other) {
  if (this == &other) {
    return *this;
  }
  setup_ = std::move(other.setup_);
  predict_ = std::move(other.predict_);
  learn_ = std::move(other.learn_);
  return *this;
}
80

81
inline bool IsComponentFunctionEqual(
82
    const vector<shared_ptr<const Instruction>>& component_function1,
83
    const vector<shared_ptr<const Instruction>>& component_function2) {
84
  if (component_function1.size() != component_function2.size()) {
85
    return false;
86
  }
87
  vector<shared_ptr<const Instruction>>::const_iterator instruction1_it =
88
      component_function1.begin();
89
  for (const shared_ptr<const Instruction>& instruction2 :
90
       component_function2) {
91
    if (*instruction2 != **instruction1_it) return false;
92
    ++instruction1_it;
93
  }
94
  CHECK(instruction1_it == component_function1.end());
95
  return true;
96
}
97

98
// Two Algorithms are equal when their setup, predict and learn component
// functions are all instruction-wise equal. The checks run in that order and
// stop at the first mismatch.
bool Algorithm::operator==(const Algorithm& other) const {
  return IsComponentFunctionEqual(setup_, other.setup_) &&
         IsComponentFunctionEqual(predict_, other.predict_) &&
         IsComponentFunctionEqual(learn_, other.learn_);
}
104

105
// Renders the Algorithm as human-readable pseudo-code. For each of Setup,
// Predict and Learn it writes a "def <Name>():" header line, followed by the
// concatenated ToString() output of that component function's instructions,
// in order.
string Algorithm::ToReadable() const {
  ostringstream stream;
  // Appends one component function: its header line, then its instructions.
  const auto append_section =
      [&stream](const char* header,
                const vector<shared_ptr<const Instruction>>& instructions) {
        stream << header << std::endl;
        for (const auto& instruction : instructions) {
          stream << instruction->ToString();
        }
      };
  append_section("def Setup():", setup_);
  append_section("def Predict():", predict_);
  append_section("def Learn():", learn_);
  return stream.str();
}
121

122
// Serializes the Algorithm. Each instruction of the setup, predict and learn
// component functions is serialized and appended, in order, to the matching
// repeated field of the returned SerializedAlgorithm.
SerializedAlgorithm Algorithm::ToProto() const {
  SerializedAlgorithm serialized;
  for (const auto& instruction : setup_) {
    auto* slot = serialized.add_setup_instructions();
    *slot = instruction->Serialize();
  }
  for (const auto& instruction : predict_) {
    auto* slot = serialized.add_predict_instructions();
    *slot = instruction->Serialize();
  }
  for (const auto& instruction : learn_) {
    auto* slot = serialized.add_learn_instructions();
    *slot = instruction->Serialize();
  }
  return serialized;
}
135

136
void Algorithm::FromProto(const SerializedAlgorithm& checkpoint_algorithm) {
137
  setup_.reserve(checkpoint_algorithm.setup_instructions_size());
138
  setup_.clear();
139
  for (const SerializedInstruction& checkpoint_instruction :
140
       checkpoint_algorithm.setup_instructions()) {
141
    setup_.emplace_back(
142
        make_shared<const Instruction>(checkpoint_instruction));
143
  }
144

145
  predict_.reserve(checkpoint_algorithm.predict_instructions_size());
146
  predict_.clear();
147
  for (const SerializedInstruction& checkpoint_instruction :
148
       checkpoint_algorithm.predict_instructions()) {
149
    predict_.emplace_back(
150
        make_shared<const Instruction>(checkpoint_instruction));
151
  }
152

153
  learn_.reserve(checkpoint_algorithm.learn_instructions_size());
154
  learn_.clear();
155
  for (const SerializedInstruction& checkpoint_instruction :
156
       checkpoint_algorithm.learn_instructions()) {
157
    learn_.emplace_back(
158
        make_shared<const Instruction>(checkpoint_instruction));
159
  }
160
}
161

162
// Returns a read-only reference to the component function (setup, predict or
// learn) selected by `component_function_type`.
//
// The switch deliberately has no `default:` label, so the compiler's -Wswitch
// still reports any enumerator added later but not handled here. A value that
// matches no case (e.g. an out-of-range integer cast to ComponentFunctionT)
// previously flowed off the end of this reference-returning function, which is
// undefined behavior. It now aborts the process instead.
const vector<shared_ptr<const Instruction>>& Algorithm::ComponentFunction(
    const ComponentFunctionT component_function_type) const {
  switch (component_function_type) {
    case kSetupComponentFunction:
      return setup_;
    case kPredictComponentFunction:
      return predict_;
    case kLearnComponentFunction:
      return learn_;
  }
  // Unreachable for valid enumerators; fail loudly rather than return garbage.
  std::abort();
}
173

174
// Returns a mutable pointer to the component function (setup, predict or
// learn) selected by `component_function_type`. The pointer is never null for
// a handled enumerator.
//
// The switch deliberately has no `default:` label, so the compiler's -Wswitch
// still reports any enumerator added later but not handled here. A value that
// matches no case previously flowed off the end of this pointer-returning
// function (undefined behavior; callers would receive an indeterminate
// pointer). It now aborts the process instead.
vector<shared_ptr<const Instruction>>* Algorithm::MutableComponentFunction(
    const ComponentFunctionT component_function_type) {
  switch (component_function_type) {
    case kSetupComponentFunction:
      return &setup_;
    case kPredictComponentFunction:
      return &predict_;
    case kLearnComponentFunction:
      return &learn_;
  }
  // Unreachable for valid enumerators; fail loudly rather than return garbage.
  std::abort();
}
185

186
}  // namespace automl_zero
187

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.