google-research

Форк
0
/
es_enas_learner_grpc.py 
270 строк · 9.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""GRPC for proposing perturbations and topologies, and receiving objectives."""
17
import sys
18
import time
19
from absl import logging
20
import numpy as np
21
import pyglove as pg
22

23
from es_enas import util
24

25

26
def _make_request(current_input, hyperparameters, perturbation, tag,
                  topology_str):
  """Builds one worker request dict for EvaluateBlackboxInput."""
  return {
      'current_input': current_input,
      'hyperparameters': hyperparameters,
      'perturbation': perturbation,
      'tag': tag,
      'topology_str': topology_str
  }


def propose_queries_blackbox_optimizer(config, current_input,
                                       blackbox_optimizer, iteration):
  """Proposes perturbations and topology strings for one iteration.

  For each k in 1..config.total_num_perturbations a Gaussian perturbation
  (scaled by config.es_precision_parameter) is proposed with tag k. With
  antithetic sampling (config.est_type == 'antithetic') its negation follows
  immediately with tag -k. Afterwards config.num_exact_evals requests with a
  zero perturbation and tag 0 are appended. Every request gets its own DNA
  from config.controller.propose_dna(), serialized with pg.to_json.

  Args:
    config: Experiment config (fields listed above).
    current_input: Current parameter vector; its length sets the
      perturbation size.
    blackbox_optimizer: Provides the hyperparameters sent with every request.
    iteration: Iteration number, used only for logging.

  Returns:
    A tuple (requests, proposed_perturbations, proposed_dnas). requests and
    proposed_dnas are aligned one-to-one. proposed_perturbations holds only
    the non-zero (tag != 0) perturbations, in request order.
  """
  start_time = time.time()

  core_hyperparameters = blackbox_optimizer.get_hyperparameters()
  proposed_perturbations = []
  proposed_dnas = []
  requests = []

  def _add_query(perturbation, tag):
    # Each query is paired with a freshly proposed architecture.
    dna = config.controller.propose_dna()
    proposed_dnas.append(dna)
    requests.append(
        _make_request(current_input, core_hyperparameters, perturbation, tag,
                      pg.to_json(dna)))

  for tag in range(1, config.total_num_perturbations + 1):
    perturbation = np.random.normal(
        size=len(current_input)) * config.es_precision_parameter
    proposed_perturbations.append(perturbation)
    _add_query(perturbation, tag)

    if config.est_type == 'antithetic':
      antiperturbation = -perturbation
      proposed_perturbations.append(antiperturbation)
      _add_query(antiperturbation, -tag)

  # Tag 0 marks an exact (unperturbed) evaluation of current_input.
  for _ in range(config.num_exact_evals):
    _add_query(np.zeros_like(current_input), 0)

  end_time = time.time()
  logging.info('Iteration %d, requests proposed in %f seconds', iteration,
               end_time - start_time)

  return requests, proposed_perturbations, proposed_dnas
95

96

97
def run_step_blackbox_optimizer(config,
                                current_input,
                                blackbox_optimizer,
                                proposed_perturbations,
                                finished_dnas,
                                results,
                                logging_data=None):
  """Runs one ES + controller training step from collected worker results.

  Args:
    config: Experiment config. Reads `est_type` and `controller`; when
      `logging_data` is given, also `log_frequency` and the `*_file` paths.
    current_input: Current parameter vector of the blackbox optimizer.
    blackbox_optimizer: Optimizer whose `run_step` returns the next input.
    proposed_perturbations: Perturbations in the order produced by
      `propose_queries_blackbox_optimizer`; result tags index into it.
    finished_dnas: DNAs of the successfully evaluated requests, aligned with
      `results`.
    results: Worker responses, each indexed by 'function_value', 'tag' and
      'evaluation_stat'.
    logging_data: Optional dict with 'iteration', 'best_value' (one-element
      list), 'best_input' and 'best_core_hyperparameters'. When the current
      exact value beats best_value[0], all three best entries are updated
      in place so the caller can carry them to the next iteration.

  Returns:
    [True, new_current_input].

  Raises:
    ValueError: If `results` contains no exact (tag 0) evaluation.
  """
  core_hyperparameters = blackbox_optimizer.get_hyperparameters()
  # NOTE(review): a perturbation whose worker failed keeps value 0.0 and is
  # still passed to run_step; confirm the optimizer tolerates this.
  function_values = [0.0] * len(proposed_perturbations)
  rewards_for_controller = []
  exact_total = 0.0
  exact_count = 0

  for result in results:
    value = result['function_value']
    rewards_for_controller.append(value)
    tag = result['tag']
    if tag > 0:
      # Antithetic layout interleaves pairs: tag k -> 2(k-1), -k -> 2(k-1)+1.
      if config.est_type == 'antithetic':
        index = (tag - 1) * 2
      else:
        index = tag - 1
      function_values[index] += value
    elif tag < 0:
      function_values[(-tag - 1) * 2 + 1] += value
    else:
      exact_total += value
      exact_count += 1

  if exact_count == 0:
    raise ValueError(
        'No exact (tag 0) evaluation among the results; cannot estimate the '
        'current objective value.')
  current_value_exact = exact_total / float(exact_count)

  evaluation_stats = [list(result['evaluation_stat']) for result in results]

  function_values_reshaped = np.array(function_values)
  perturbations_reshaped = np.array(proposed_perturbations)

  logging.info('LIST OF FUNCTION VALUES')
  logging.info(function_values_reshaped)

  logging.info('MAX VALUE SEEN CURRENTLY')
  logging.info(np.max(function_values_reshaped))

  logging.info('MEAN OF VALUES')
  logging.info(np.mean(function_values_reshaped))

  if logging_data is not None:
    iteration = logging_data['iteration']
    best_value = logging_data['best_value']
    best_input = logging_data['best_input']
    best_core_hyperparameters = logging_data['best_core_hyperparameters']
    # Captured before run_step below, i.e. the pre-step optimizer state.
    optimizer_state = blackbox_optimizer.get_state()

    if current_value_exact > best_value[0]:
      best_value[0] = current_value_exact
      best_input = current_input
      best_core_hyperparameters = core_hyperparameters
      # best_value is mutated in place; the other two are rebound, so write
      # them back or the caller would keep logging a stale best input.
      logging_data['best_input'] = best_input
      logging_data['best_core_hyperparameters'] = best_core_hyperparameters

    # Writing logs.
    if iteration % config.log_frequency == 0:
      util.log_row(config.params_file, current_input)
      util.log_row(config.best_params_file, best_input)
      util.log_row(config.best_core_hyperparameters_file,
                   best_core_hyperparameters)
      util.log_row(config.best_value_file, best_value)
      util.log_row(config.optimizer_internal_state_file, optimizer_state)
      util.log_row(config.current_values_list_file, [current_value_exact])
      util.log_row(config.best_values_list_file, [best_value[0]])
      util.log_row(config.fvalues_file, function_values_reshaped)
      util.log_row(config.iteration_file, [iteration])

    print('Current exact value estimate:')
    print(current_value_exact)
    sys.stdout.flush()

  new_current_input = blackbox_optimizer.run_step(perturbations_reshaped,
                                                  function_values_reshaped,
                                                  current_input,
                                                  current_value_exact)
  config.controller.collect_rewards_and_train(rewards_for_controller,
                                              finished_dnas)

  # Element-wise sum of the per-result evaluation statistics.
  evaluation_stats_reduced = [sum(x) for x in zip(*evaluation_stats)]
  blackbox_optimizer.update_state(evaluation_stats_reduced)

  return [True, new_current_input]


def run_step_rpc_blackbox_optimizer(config,
                                    current_input,
                                    blackbox_optimizer,
                                    workers,
                                    iteration,
                                    best_input,
                                    best_core_hyperparameters,
                                    best_value,
                                    log_bool=False,
                                    best_state=None):
  """Sends one round of queries to the workers and runs a training step.

  Requests are paired with `workers` one-to-one via `zip`, so requests
  beyond len(workers) are not sent.

  Args:
    config: Experiment config; also reads `critical` (tolerated failure
      fraction) and `num_exact_evals` indirectly through the proposals.
    current_input: Current parameter vector.
    blackbox_optimizer: The blackbox optimizer being trained.
    workers: Stubs exposing `EvaluateBlackboxInput.future(request)`.
    iteration: Current iteration number.
    best_input: Best input seen so far.
    best_core_hyperparameters: Hyperparameters at the best input.
    best_value: One-element list holding the best value; updated in place.
    log_bool: Whether to track the best values and write logs.
    best_state: Optional dict reused as the logging-data dict when
      `log_bool` is set. It is refreshed from the best_* arguments, then
      updated in place when a new best is found, so the caller can read the
      new best input/hyperparameters back after the call.

  Returns:
    [success, current_input]. [False, current_input] (input unchanged) when
    too many workers failed or every dispatched exact evaluation failed;
    otherwise the result of `run_step_blackbox_optimizer`.
  """
  requests, proposed_perturbations, proposed_dnas = (
      propose_queries_blackbox_optimizer(config, current_input,
                                         blackbox_optimizer, iteration))

  dispatched = list(zip(workers, requests))
  futures = [
      stub.EvaluateBlackboxInput.future(request) for stub, request in dispatched
  ]

  results = []
  finished_dnas = []
  num_worker_failures = 0
  start = time.time()
  for w, future in enumerate(futures):
    try:
      result = future.result()
    except Exception as e:  # pylint: disable=broad-except
      # Tolerate individual RPC failures without swallowing
      # KeyboardInterrupt/SystemExit as the previous bare except did.
      print('RPC error caught in collecting results !')
      num_worker_failures += 1
      logging.info('worker failed ID: %d (%s)', w, e)
    else:
      results.append(result)
      finished_dnas.append(proposed_dnas[w])

  end = time.time()
  print('Responds received in time: [in sec].')
  print(end - start)
  sys.stdout.flush()
  if float(num_worker_failures) > config.critical * float(len(workers)):
    return [False, current_input]

  # The exact value estimate needs at least one tag-0 result; if every
  # dispatched exact evaluation failed, retry the step instead of crashing.
  exact_dispatched = any(request['tag'] == 0 for _, request in dispatched)
  exact_received = any(result['tag'] == 0 for result in results)
  if exact_dispatched and not exact_received:
    logging.warning('Iteration %d: all exact evaluations failed; retrying.',
                    iteration)
    return [False, current_input]

  if log_bool:
    logging_data = best_state if best_state is not None else {}
    logging_data.update({
        'best_value': best_value,
        'iteration': iteration,
        'best_input': best_input,
        'best_core_hyperparameters': best_core_hyperparameters
    })
  else:
    logging_data = None

  return run_step_blackbox_optimizer(config, current_input, blackbox_optimizer,
                                     proposed_perturbations, finished_dnas,
                                     results, logging_data)


def run_optimization(config,
                     blackbox_optimizer,
                     init_current_input,
                     init_best_input,
                     init_best_core_hyperparameters,
                     init_best_value,
                     init_iteration,
                     workers,
                     log_bool=False):
  """Runs the optimization loop until `config.nb_iterations` is reached.

  A step reporting failure is retried without advancing the iteration
  counter. Starting at or beyond `config.nb_iterations` runs no steps
  (previously that case never terminated).
  """
  current_input = init_current_input
  best_input = init_best_input
  best_core_hyperparameters = init_best_core_hyperparameters
  best_value = [init_best_value]  # One-element list, updated in place.
  iteration = init_iteration
  # Shared with the RPC step so a best input/hyperparameters found in one
  # iteration is carried into the next instead of reverting to the initial.
  best_state = {}

  while iteration < config.nb_iterations:
    print(iteration)
    sys.stdout.flush()
    success, current_input = run_step_rpc_blackbox_optimizer(
        config, current_input, blackbox_optimizer, workers, iteration,
        best_input, best_core_hyperparameters, best_value, log_bool,
        best_state=best_state)
    if best_state:
      best_input = best_state['best_input']
      best_core_hyperparameters = best_state['best_core_hyperparameters']
    if success:
      iteration += 1
271

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.