# google-research
# 270 lines · 9.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""GRPC for proposing perturbations and topologies, and receiving objectives."""
17import sys
18import time
19from absl import logging
20import numpy as np
21import pyglove as pg
22
23from es_enas import util
24
25
def propose_queries_blackbox_optimizer(config, current_input,
                                       blackbox_optimizer, iteration):
    """Builds the evaluation requests for one optimization iteration.

    For each of `config.total_num_perturbations` a Gaussian perturbation of
    length `len(current_input)`, scaled by `config.es_precision_parameter`, is
    drawn and paired with a DNA proposed by `config.controller`. When
    `config.est_type` is 'antithetic', the negated perturbation is queued as
    well, with its own freshly proposed DNA and the negated tag. Finally,
    `config.num_exact_evals` requests with an all-zero perturbation and tag 0
    are appended.

    Args:
      config: Object providing `total_num_perturbations`,
        `es_precision_parameter`, `est_type`, `num_exact_evals` and
        `controller` (with a `propose_dna()` method).
      current_input: Current parameter vector; sent unchanged in every request.
      blackbox_optimizer: Optimizer whose `get_hyperparameters()` result is
        attached to every request.
      iteration: Iteration number, used only in the timing log message.

    Returns:
      A tuple `(requests, proposed_perturbations, proposed_dnas)`. `requests`
      is a list of dicts with keys 'current_input', 'hyperparameters',
      'perturbation', 'tag' and 'topology_str'. `proposed_perturbations` holds
      the sampled (and antithetic) perturbations but not the zero ones.
      `proposed_dnas` holds one DNA per request, in request order.
    """
    began_at = time.time()

    hyperparameters = blackbox_optimizer.get_hyperparameters()
    proposed_perturbations = []
    proposed_dnas = []
    requests = []

    def enqueue(perturbation, tag):
        # Every request gets its own topology; the DNA is kept alongside so the
        # controller can later be trained on the matching reward.
        dna = config.controller.propose_dna()
        topology_str = pg.to_json(dna)
        proposed_dnas.append(dna)
        requests.append({
            'current_input': current_input,
            'hyperparameters': hyperparameters,
            'perturbation': perturbation,
            'tag': tag,
            'topology_str': topology_str
        })

    # Tags start at 1 so that tag 0 stays reserved for exact evaluations and
    # the sign can distinguish a perturbation from its antithetic partner.
    for tag in range(1, config.total_num_perturbations + 1):
        direction = np.random.normal(size=len(current_input))
        perturbation = direction * config.es_precision_parameter
        proposed_perturbations.append(perturbation)
        enqueue(perturbation, tag)

        if config.est_type == 'antithetic':
            antiperturbation = -perturbation
            proposed_perturbations.append(antiperturbation)
            enqueue(antiperturbation, -tag)

    # Unperturbed evaluations of the current input.
    for _ in range(config.num_exact_evals):
        enqueue(np.zeros_like(current_input), 0)

    elapsed = time.time() - began_at
    logging.info('Iteration %d, requests proposed in %f seconds', iteration,
                 elapsed)

    return requests, proposed_perturbations, proposed_dnas
95
96
def run_step_blackbox_optimizer(config,
                                current_input,
                                blackbox_optimizer,
                                proposed_perturbations,
                                finished_dnas,
                                results,
                                logging_data=None):
    """Runs one training step from the results collected for an iteration.

    Each result's 'function_value' is accumulated into a vector aligned with
    `proposed_perturbations` according to the result's 'tag':

      * tag > 0: slot `(tag - 1) * 2` when `config.est_type` is 'antithetic',
        otherwise slot `tag - 1`;
      * tag < 0: slot `(-tag - 1) * 2 + 1` (the antithetic partner);
      * tag == 0: averaged into the exact value estimate of `current_input`.

    Slots for which no result arrived keep the value 0.0. If no tag-0 result
    arrived at all, the exact value estimate falls back to 0.0 (a warning is
    logged) instead of raising ZeroDivisionError, and that placeholder is never
    recorded as a new best value.

    Args:
      config: Object providing `est_type` and `controller`; when
        `logging_data` is given, also `log_frequency` and the log file
        attributes passed to `util.log_row`.
      current_input: Current parameter vector.
      blackbox_optimizer: Optimizer exposing `get_hyperparameters`, `run_step`,
        `update_state` and, when logging, `get_state`.
      proposed_perturbations: Perturbations proposed for this iteration.
      finished_dnas: DNAs of the evaluations that completed; forwarded to the
        controller together with the rewards.
      results: Sequence of mappings with keys 'function_value', 'tag' and
        'evaluation_stat'.
      logging_data: Optional dict with keys 'iteration', 'best_value' (a
        one-element list, updated in place), 'best_input' and
        'best_core_hyperparameters'. `None` disables logging.

    Returns:
      `[True, new_current_input]`, where `new_current_input` is whatever
      `blackbox_optimizer.run_step` returned.
    """
    core_hyperparameters = blackbox_optimizer.get_hyperparameters()
    function_values = [0.0] * len(proposed_perturbations)
    rewards_for_controller = []
    perturbations = proposed_perturbations
    evaluation_stats = []
    exact_value_sum = 0.0
    exact_value_count = 0
    is_antithetic = config.est_type == 'antithetic'

    for result in results:
        value = result['function_value']
        tag = result['tag']
        rewards_for_controller.append(value)
        evaluation_stats.append(list(result['evaluation_stat']))

        if tag > 0:
            # Antithetic pairs are interleaved: forward at even slots.
            index = (tag - 1) * 2 if is_antithetic else tag - 1
            function_values[index] += value
        elif tag < 0:
            # ...and the negated perturbation right after its partner.
            function_values[(-tag - 1) * 2 + 1] += value
        elif tag == 0:
            exact_value_sum += value
            exact_value_count += 1

    has_exact_estimate = exact_value_count > 0
    if has_exact_estimate:
        current_value_exact = exact_value_sum / float(exact_value_count)
    else:
        # Happens when no exact evaluation was requested or none of them
        # completed; dividing would raise ZeroDivisionError.
        logging.warning('No exact evaluation (tag 0) results received; '
                        'using 0.0 as the exact value estimate.')
        current_value_exact = 0.0

    function_values_reshaped = np.array(function_values)
    perturbations_reshaped = np.array(perturbations)

    logging.info('LIST OF FUNCTION VALUES')
    logging.info(function_values_reshaped)

    logging.info('MAX VALUE SEEN CURRENTLY')
    logging.info(np.max(function_values_reshaped))

    logging.info('MEAN OF VALUES')
    logging.info(np.mean(function_values_reshaped))

    if logging_data is not None:
        iteration = logging_data['iteration']
        best_value = logging_data['best_value']
        best_input = logging_data['best_input']
        best_core_hyperparameters = logging_data['best_core_hyperparameters']
        optimizer_state = blackbox_optimizer.get_state()

        # NOTE(review): only `best_value` is updated in place; the new
        # `best_input` / `best_core_hyperparameters` are local to this call and
        # are visible only to the log rows written below.
        if has_exact_estimate and current_value_exact > best_value[0]:
            best_value[0] = current_value_exact
            best_input = current_input
            best_core_hyperparameters = core_hyperparameters

        # Writing logs.
        if iteration % config.log_frequency == 0:
            util.log_row(config.params_file, current_input)
            util.log_row(config.best_params_file, best_input)
            util.log_row(config.best_core_hyperparameters_file,
                         best_core_hyperparameters)
            util.log_row(config.best_value_file, best_value)
            util.log_row(config.optimizer_internal_state_file, optimizer_state)
            util.log_row(config.current_values_list_file, [current_value_exact])
            util.log_row(config.best_values_list_file, [best_value[0]])
            util.log_row(config.fvalues_file, function_values_reshaped)
            util.log_row(config.iteration_file, [iteration])

        print('Current exact value estimate:')
        print(current_value_exact)
        sys.stdout.flush()

    new_current_input = blackbox_optimizer.run_step(perturbations_reshaped,
                                                    function_values_reshaped,
                                                    current_input,
                                                    current_value_exact)
    config.controller.collect_rewards_and_train(rewards_for_controller,
                                                finished_dnas)

    # Element-wise sum of the per-result evaluation statistics.
    evaluation_stats_reduced = [sum(x) for x in zip(*evaluation_stats)]
    blackbox_optimizer.update_state(evaluation_stats_reduced)

    return [True, new_current_input]
189
190
def run_step_rpc_blackbox_optimizer(config,
                                    current_input,
                                    blackbox_optimizer,
                                    workers,
                                    iteration,
                                    best_input,
                                    best_core_hyperparameters,
                                    best_value,
                                    log_bool=False):
    """Sends the proposed queries to the workers and runs one training step.

    Requests are paired with workers via `zip`, so at most `len(workers)`
    requests are dispatched; any extra requests are not sent. A worker whose
    future raises is counted as failed and its DNA is left out of the DNAs
    passed on to the training step.

    Args:
      config: Configuration object; `config.critical` is the tolerated
        fraction of failed workers.
      current_input: Current parameter vector.
      blackbox_optimizer: Optimizer used to propose queries and run the step.
      workers: Stubs exposing `EvaluateBlackboxInput.future(request)`.
      iteration: Current iteration number.
      best_input: Best input so far; forwarded for logging.
      best_core_hyperparameters: Hyperparameters of the best input; forwarded
        for logging.
      best_value: One-element list holding the best value so far; forwarded
        for logging.
      log_bool: Whether to pass logging data to the training step.

    Returns:
      `[False, current_input]` if more than `config.critical * len(workers)`
      workers failed; otherwise the result of `run_step_blackbox_optimizer`.
    """
    requests, proposed_perturbations, proposed_dnas = (
        propose_queries_blackbox_optimizer(config, current_input,
                                           blackbox_optimizer, iteration))

    # Fire all requests first so the workers evaluate concurrently.
    futures = [
        stub.EvaluateBlackboxInput.future(request)
        for stub, request in zip(workers, requests)
    ]

    finished_dnas = []
    results = []
    num_worker_failures = 0
    start = time.time()
    for w, future in enumerate(futures):
        try:
            result = future.result()
        # Catch Exception rather than everything, so KeyboardInterrupt and
        # SystemExit still stop the run instead of counting as failures.
        except Exception:  # pylint: disable=broad-except
            print('RPC error caught in collecting results !')
            num_worker_failures += 1
            logging.info('worker failed ID: ')
            logging.info(w)
        else:
            results.append(result)
            finished_dnas.append(proposed_dnas[w])

    end = time.time()
    print('Responds received in time: [in sec].')
    print(end - start)
    sys.stdout.flush()
    if float(num_worker_failures) > config.critical * float(len(workers)):
        # Too many failures: report no progress so the caller can retry.
        return [False, current_input]

    if log_bool:
        logging_data = {
            'best_value': best_value,
            'iteration': iteration,
            'best_input': best_input,
            'best_core_hyperparameters': best_core_hyperparameters
        }
    else:
        logging_data = None

    return run_step_blackbox_optimizer(config, current_input,
                                       blackbox_optimizer,
                                       proposed_perturbations, finished_dnas,
                                       results, logging_data)
243
244
def run_optimization(config,
                     blackbox_optimizer,
                     init_current_input,
                     init_best_input,
                     init_best_core_hyperparameters,
                     init_best_value,
                     init_iteration,
                     workers,
                     log_bool=False):
    """Runs the entire optimization procedure.

    Repeats RPC training steps until the iteration counter reaches
    `config.nb_iterations`. The counter advances only after a successful step,
    so a failed step is retried with the same iteration number and input.

    Args:
      config: Configuration object; `config.nb_iterations` bounds the loop.
      blackbox_optimizer: Optimizer passed to every step.
      init_current_input: Starting parameter vector.
      init_best_input: Best input known at start; forwarded for logging.
      init_best_core_hyperparameters: Hyperparameters of the best input known
        at start; forwarded for logging.
      init_best_value: Best value known at start.
      init_iteration: Iteration number to start (or resume) from.
      workers: Worker stubs passed to every step.
      log_bool: Whether each step writes logs.
    """
    current_input = init_current_input
    best_input = init_best_input
    best_core_hyperparameters = init_best_core_hyperparameters
    # Wrapped in a list so the step function can update it in place.
    best_value = [init_best_value]
    iteration = init_iteration

    while True:
        print(iteration)
        sys.stdout.flush()
        success, current_input = run_step_rpc_blackbox_optimizer(
            config, current_input, blackbox_optimizer, workers, iteration,
            best_input, best_core_hyperparameters, best_value, log_bool)
        if success:
            iteration += 1
        # `>=` rather than `==`: resuming with init_iteration at or beyond
        # nb_iterations would otherwise never hit the bound and loop forever.
        if iteration >= config.nb_iterations:
            break
271