google-research

Форк
0
149 строк · 3.8 Кб
1
// Copyright 2024 The Google Research Authors.
2
//
3
// Licensed under the Apache License, Version 2.0 (the "License");
4
// you may not use this file except in compliance with the License.
5
// You may obtain a copy of the License at
6
//
7
//     http://www.apache.org/licenses/LICENSE-2.0
8
//
9
// Unless required by applicable law or agreed to in writing, software
10
// distributed under the License is distributed on an "AS IS" BASIS,
11
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
// See the License for the specific language governing permissions and
13
// limitations under the License.
14

15
#include "fec_cache.h"
16

17
#include <utility>
18

19
#include "executor.h"
20
#include "fec_hashing.h"
21

22
namespace automl_zero {
23

24
using ::std::make_pair;
25
using ::std::vector;
26
using K = LRUCache::K;
27
using V = LRUCache::V;
28

29
LRUCache::LRUCache(IntegerT max_size)
30
    : max_size_(max_size) {
31
  CHECK_GT(max_size, 1);
32
}
33

34
V* LRUCache::Insert(const K key, const V& value) {
35
  // If already inserted, erase it.
36
  MapIterator found = map_.find(key);
37
  if (found != map_.end()) EraseImpl(found);
38
  V* inserted = InsertImpl(key, value);
39
  MaybeResize();
40
  return inserted;
41
}
42

43
const V* LRUCache::Lookup(const K key) {
44
  MapIterator found = map_.find(key);
45
  if (found == map_.end()) {
46
    // If not found, return nullptr.
47
    return nullptr;
48
  } else {
49
    // If found, return it.
50
    return &found->second->second;
51
  }
52
}
53

54
V* LRUCache::MutableLookup(const K key) {
55
  MapIterator found = map_.find(key);
56
  if (found == map_.end()) {
57
    // If not found, return nullptr.
58
    return nullptr;
59
  } else {
60
    // If found, move it to the front and return it.
61
    const V value = found->second->second;
62
    EraseImpl(found);
63
    return InsertImpl(key, value);
64
  }
65
}
66

67
void LRUCache::Erase(const K key) {
68
  MapIterator found = map_.find(key);
69
  CHECK(found != map_.end());
70
  EraseImpl(found);
71
}
72

73
void LRUCache::Clear() {
74
  map_.clear();
75
  list_.clear();
76
}
77

78
void LRUCache::EraseImpl(MapIterator it) {
79
  list_.erase(it->second);
80
  map_.erase(it);
81
}
82

83
V* LRUCache::InsertImpl(const K key, const V& value) {
84
  list_.push_front(make_pair(key, value));
85
  ListIterator pushed = list_.begin();
86
  map_.insert(make_pair(key, pushed));
87
  return &pushed->second;
88
}
89

90
void LRUCache::MaybeResize() {
91
  // Keep within size limit.
92
  while (list_.size() > max_size_) {
93
    // Erase last element.
94
    const K erasing = list_.back().first;
95
    list_.pop_back();
96
    map_.erase(erasing);
97
  }
98
}
99

100
FECCache::FECCache(const FECSpec& spec)
101
    : spec_(spec), cache_(spec_.cache_size()) {
102
  CHECK_GT(spec_.num_train_examples(), 0);
103
  CHECK_GT(spec_.num_valid_examples(), 0);
104
  CHECK_GT(spec_.cache_size(), 1);
105
  CHECK(spec_.forget_every() == 0 || spec_.forget_every() > 1);
106
}
107

108
size_t FECCache::Hash(
109
    const vector<double>& train_errors,
110
    const vector<double>& valid_errors,
111
    const IntegerT dataset_index, const IntegerT num_train_examples) {
112
  return WellMixedHash(train_errors, valid_errors, dataset_index,
113
                       num_train_examples);
114
}
115

116
std::pair<double, bool> FECCache::Find(const size_t hash) {
117
  CachedEvaluation* cached = cache_.MutableLookup(hash);
118
  if (cached == nullptr) {
119
    return make_pair(kMinFitness, false);
120
  } else {
121
    const double fitness = cached->fitness;
122
    ++cached->count;
123
    if (spec_.forget_every() != 0 && cached->count >= spec_.forget_every()) {
124
      cache_.Erase(hash);
125
    }
126
    return make_pair(fitness, true);
127
  }
128
}
129

130
void FECCache::InsertOrDie(
131
    const size_t hash, const double fitness) {
132
  CHECK(cache_.Lookup(hash) == nullptr);
133
  CachedEvaluation* inserted = cache_.Insert(hash, CachedEvaluation(fitness));
134
  CHECK(inserted != nullptr);
135
}
136

137
void FECCache::Clear() {
138
  cache_.Clear();
139
}
140

141
IntegerT FECCache::NumTrainExamples() const {
142
  return spec_.num_train_examples();
143
}
144

145
IntegerT FECCache::NumValidExamples() const {
146
  return spec_.num_valid_examples();
147
}
148

149
}  // namespace automl_zero
150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.