// google-research
// 186 lines · 6.0 KB
// Copyright 2024 The Google Research Authors.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

#include "algorithm.h"

#include <cstdlib>
#include <sstream>
#include <string>
#include <vector>

#include "absl/flags/flag.h"
#include "definitions.h"
#include "instruction.h"
#include "random_generator.h"
26namespace automl_zero {27
28using ::std::istringstream; // NOLINT29using ::std::make_shared; // NOLINT30using ::std::ostream; // NOLINT31using ::std::ostringstream; // NOLINT32using ::std::shared_ptr; // NOLINT33using ::std::string; // NOLINT34using ::std::stringstream; // NOLINT35using ::std::vector; // NOLINT36
37Algorithm::Algorithm(const SerializedAlgorithm& checkpoint_algorithm) {38this->FromProto(checkpoint_algorithm);39}
40
41inline void ShallowCopyComponentFunction(42const vector<shared_ptr<const Instruction>>& src,43vector<shared_ptr<const Instruction>>* dest) {44dest->reserve(src.size());45dest->clear();46for (const shared_ptr<const Instruction>& src_instr : src) {47dest->emplace_back(src_instr);48}49}
50
51Algorithm::Algorithm(const Algorithm& other) {52ShallowCopyComponentFunction(other.setup_, &this->setup_);53ShallowCopyComponentFunction(other.predict_, &this->predict_);54ShallowCopyComponentFunction(other.learn_, &this->learn_);55}
56
57Algorithm& Algorithm::operator=(const Algorithm& other) {58if (&other != this) {59ShallowCopyComponentFunction(other.setup_, &this->setup_);60ShallowCopyComponentFunction(other.predict_, &this->predict_);61ShallowCopyComponentFunction(other.learn_, &this->learn_);62}63return *this;64}
65
66Algorithm::Algorithm(Algorithm&& other) {67setup_ = std::move(other.setup_);68predict_ = std::move(other.predict_);69learn_ = std::move(other.learn_);70}
71
72Algorithm& Algorithm::operator=(Algorithm&& other) {73if (&other != this) {74setup_ = std::move(other.setup_);75predict_ = std::move(other.predict_);76learn_ = std::move(other.learn_);77}78return *this;79}
80
81inline bool IsComponentFunctionEqual(82const vector<shared_ptr<const Instruction>>& component_function1,83const vector<shared_ptr<const Instruction>>& component_function2) {84if (component_function1.size() != component_function2.size()) {85return false;86}87vector<shared_ptr<const Instruction>>::const_iterator instruction1_it =88component_function1.begin();89for (const shared_ptr<const Instruction>& instruction2 :90component_function2) {91if (*instruction2 != **instruction1_it) return false;92++instruction1_it;93}94CHECK(instruction1_it == component_function1.end());95return true;96}
97
98bool Algorithm::operator==(const Algorithm& other) const {99if (!IsComponentFunctionEqual(setup_, other.setup_)) return false;100if (!IsComponentFunctionEqual(predict_, other.predict_)) return false;101if (!IsComponentFunctionEqual(learn_, other.learn_)) return false;102return true;103}
104
105string Algorithm::ToReadable() const {106ostringstream stream;107stream << "def Setup():" << std::endl;108for (const shared_ptr<const Instruction>& instruction : setup_) {109stream << instruction->ToString();110}111stream << "def Predict():" << std::endl;112for (const shared_ptr<const Instruction>& instruction : predict_) {113stream << instruction->ToString();114}115stream << "def Learn():" << std::endl;116for (const shared_ptr<const Instruction>& instruction : learn_) {117stream << instruction->ToString();118}119return stream.str();120}
121
122SerializedAlgorithm Algorithm::ToProto() const {123SerializedAlgorithm checkpoint_algorithm;124for (const shared_ptr<const Instruction>& instr : setup_) {125*checkpoint_algorithm.add_setup_instructions() = instr->Serialize();126}127for (const shared_ptr<const Instruction>& instr : predict_) {128*checkpoint_algorithm.add_predict_instructions() = instr->Serialize();129}130for (const shared_ptr<const Instruction>& instr : learn_) {131*checkpoint_algorithm.add_learn_instructions() = instr->Serialize();132}133return checkpoint_algorithm;134}
135
136void Algorithm::FromProto(const SerializedAlgorithm& checkpoint_algorithm) {137setup_.reserve(checkpoint_algorithm.setup_instructions_size());138setup_.clear();139for (const SerializedInstruction& checkpoint_instruction :140checkpoint_algorithm.setup_instructions()) {141setup_.emplace_back(142make_shared<const Instruction>(checkpoint_instruction));143}144
145predict_.reserve(checkpoint_algorithm.predict_instructions_size());146predict_.clear();147for (const SerializedInstruction& checkpoint_instruction :148checkpoint_algorithm.predict_instructions()) {149predict_.emplace_back(150make_shared<const Instruction>(checkpoint_instruction));151}152
153learn_.reserve(checkpoint_algorithm.learn_instructions_size());154learn_.clear();155for (const SerializedInstruction& checkpoint_instruction :156checkpoint_algorithm.learn_instructions()) {157learn_.emplace_back(158make_shared<const Instruction>(checkpoint_instruction));159}160}
161
162const vector<shared_ptr<const Instruction>>& Algorithm::ComponentFunction(163const ComponentFunctionT component_function_type) const {164switch (component_function_type) {165case kSetupComponentFunction:166return setup_;167case kPredictComponentFunction:168return predict_;169case kLearnComponentFunction:170return learn_;171}172}
173
174vector<shared_ptr<const Instruction>>* Algorithm::MutableComponentFunction(175const ComponentFunctionT component_function_type) {176switch (component_function_type) {177case kSetupComponentFunction:178return &setup_;179case kPredictComponentFunction:180return &predict_;181case kLearnComponentFunction:182return &learn_;183}184}
185
186} // namespace automl_zero187