google-research
149 строк · 3.8 Кб
1// Copyright 2024 The Google Research Authors.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7// http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15#include "fec_cache.h"
16
17#include <utility>
18
19#include "executor.h"
20#include "fec_hashing.h"
21
22namespace automl_zero {
23
24using ::std::make_pair;
25using ::std::vector;
26using K = LRUCache::K;
27using V = LRUCache::V;
28
29LRUCache::LRUCache(IntegerT max_size)
30: max_size_(max_size) {
31CHECK_GT(max_size, 1);
32}
33
34V* LRUCache::Insert(const K key, const V& value) {
35// If already inserted, erase it.
36MapIterator found = map_.find(key);
37if (found != map_.end()) EraseImpl(found);
38V* inserted = InsertImpl(key, value);
39MaybeResize();
40return inserted;
41}
42
43const V* LRUCache::Lookup(const K key) {
44MapIterator found = map_.find(key);
45if (found == map_.end()) {
46// If not found, return nullptr.
47return nullptr;
48} else {
49// If found, return it.
50return &found->second->second;
51}
52}
53
54V* LRUCache::MutableLookup(const K key) {
55MapIterator found = map_.find(key);
56if (found == map_.end()) {
57// If not found, return nullptr.
58return nullptr;
59} else {
60// If found, move it to the front and return it.
61const V value = found->second->second;
62EraseImpl(found);
63return InsertImpl(key, value);
64}
65}
66
67void LRUCache::Erase(const K key) {
68MapIterator found = map_.find(key);
69CHECK(found != map_.end());
70EraseImpl(found);
71}
72
73void LRUCache::Clear() {
74map_.clear();
75list_.clear();
76}
77
78void LRUCache::EraseImpl(MapIterator it) {
79list_.erase(it->second);
80map_.erase(it);
81}
82
83V* LRUCache::InsertImpl(const K key, const V& value) {
84list_.push_front(make_pair(key, value));
85ListIterator pushed = list_.begin();
86map_.insert(make_pair(key, pushed));
87return &pushed->second;
88}
89
90void LRUCache::MaybeResize() {
91// Keep within size limit.
92while (list_.size() > max_size_) {
93// Erase last element.
94const K erasing = list_.back().first;
95list_.pop_back();
96map_.erase(erasing);
97}
98}
99
100FECCache::FECCache(const FECSpec& spec)
101: spec_(spec), cache_(spec_.cache_size()) {
102CHECK_GT(spec_.num_train_examples(), 0);
103CHECK_GT(spec_.num_valid_examples(), 0);
104CHECK_GT(spec_.cache_size(), 1);
105CHECK(spec_.forget_every() == 0 || spec_.forget_every() > 1);
106}
107
108size_t FECCache::Hash(
109const vector<double>& train_errors,
110const vector<double>& valid_errors,
111const IntegerT dataset_index, const IntegerT num_train_examples) {
112return WellMixedHash(train_errors, valid_errors, dataset_index,
113num_train_examples);
114}
115
116std::pair<double, bool> FECCache::Find(const size_t hash) {
117CachedEvaluation* cached = cache_.MutableLookup(hash);
118if (cached == nullptr) {
119return make_pair(kMinFitness, false);
120} else {
121const double fitness = cached->fitness;
122++cached->count;
123if (spec_.forget_every() != 0 && cached->count >= spec_.forget_every()) {
124cache_.Erase(hash);
125}
126return make_pair(fitness, true);
127}
128}
129
130void FECCache::InsertOrDie(
131const size_t hash, const double fitness) {
132CHECK(cache_.Lookup(hash) == nullptr);
133CachedEvaluation* inserted = cache_.Insert(hash, CachedEvaluation(fitness));
134CHECK(inserted != nullptr);
135}
136
137void FECCache::Clear() {
138cache_.Clear();
139}
140
141IntegerT FECCache::NumTrainExamples() const {
142return spec_.num_train_examples();
143}
144
145IntegerT FECCache::NumValidExamples() const {
146return spec_.num_valid_examples();
147}
148
149} // namespace automl_zero
150