google-research

Форк
0
/
simulation.py 
256 строк · 9.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""GON experimental code on Rosenbrock and Griewank data."""
17

18
from __future__ import absolute_import
19
from __future__ import division
20
from __future__ import print_function
21

22
import random
23
from absl import app
24
from absl import flags
25
import numpy as np
26
import tensorflow as tf
27
import tensorflow_lattice as tfl
28

29
# Each DEFINE_* call returns an absl FlagHolder; code must read the parsed
# value through `.value` (holders support no arithmetic or comparisons).
_CGON_OR_GON = flags.DEFINE_enum("cgon_or_gon", "BOTH", ["CGON", "GON", "BOTH"],
                                 "CGON, GON or BOTH")
# Data generation hparams
_NUM_EXAMPLES = flags.DEFINE_integer(
    "num_examples", 10000, "number of training examples.", lower_bound=1)
_INPUT_DIM = flags.DEFINE_integer(
    "input_dim", 16, "number input dimensions.", lower_bound=1)
_BOUND = flags.DEFINE_float("bound", 2.0, "(soft) bound for sampling.")
_NOISE_SIGNAL_RATIO = flags.DEFINE_float(
    "noise_signal_ratio", 0.25,
    "standard deviation of the noise as a factor of the target function value "
    "in the target function.")
# GON/CGON hparams
_LATTICE_DIM = flags.DEFINE_integer(
    "lattice_dim", 3, "number GON lattice input dimensions.", lower_bound=1)
# Training hparams
# lower_bound=1 on num_examples/steps keeps the epoch computation in main()
# well defined (num_examples is a divisor).
_STEPS = flags.DEFINE_integer(
    "steps", 100000, "number of training steps.", lower_bound=1)
44

45

46
def sample_from_uniform_square(count=1000, dim=2, bound=None):
  """Samples points uniformly from the hypercube [-bound, bound)^dim.

  Args:
    count: number of points to draw.
    dim: number of coordinates per point.
    bound: half-width of the hypercube. When None, the `--bound` flag value is
      used (flags must already be parsed in that case).

  Returns:
    A float array of shape (count, dim) drawn from numpy's global RNG.
  """
  if bound is None:
    # FlagHolder does not support unary minus; read the parsed value.
    bound = _BOUND.value
  return np.random.uniform(low=-bound, high=bound, size=(count, dim))
49

50

51
def rosenbrock(points, noise_signal_ratio=0.0):
  """Evaluates the Rosenbrock function on each row of `points`.

  f(x) = sum_{i=0}^{d-2} (1 - x_i)^2 + 100 * (x_{i+1} - x_i^2)^2, which is 0
  at x = (1, ..., 1).

  Args:
    points: array-like of shape (n, d), one point per row. The dimension d is
      taken from the array itself.
    noise_signal_ratio: the Gaussian noise added to each value has standard
      deviation |f(x)| * noise_signal_ratio.

  Returns:
    A float array of shape (n,).
  """
  points = np.asarray(points, dtype=float)
  head, tail = points[:, :-1], points[:, 1:]
  summation = np.sum((1.0 - head)**2 + 100.0 * (tail - head**2)**2, axis=1)
  # Noise is drawn even when the ratio is 0 (scale 0), keeping RNG usage
  # independent of the ratio.
  noise = np.random.normal(
      scale=abs(summation * noise_signal_ratio), size=len(points))
  return summation + noise
59

60

61
def griewank(points, noise_signal_ratio=0.0):
  """Evaluates a Griewank function shifted to be centered at (1, ..., 1).

  f(x) = 1 + sum_i (x_i - 1)^2 / 4000 - prod_i cos((x_i - 1) / sqrt(i + 1)),
  with i = 0..d-1, which is 0 at x = (1, ..., 1).

  Args:
    points: array-like of shape (n, d), one point per row. The dimension d is
      taken from the array itself.
    noise_signal_ratio: the Gaussian noise added to each value has standard
      deviation |f(x)| * noise_signal_ratio.

  Returns:
    A float array of shape (n,).
  """
  points = np.asarray(points, dtype=float)
  shifted = points - 1.0
  divisors = np.sqrt(np.arange(1, points.shape[1] + 1))
  summation = np.sum(shifted**2, axis=1)
  product = np.prod(np.cos(shifted / divisors), axis=1)
  y = 1 + summation / 4000 - product
  noise = np.random.normal(scale=abs(y * noise_signal_ratio), size=len(points))
  return y + noise
70

71

72
def parabola(points, noise_signal_ratio=0.0):
  """Evaluates sum_i (x_i - 1)^2 on each row of `points`.

  Args:
    points: array-like of shape (n, d), one point per row. The dimension d is
      taken from the array itself.
    noise_signal_ratio: the Gaussian noise added to each value has standard
      deviation |f(x)| * noise_signal_ratio.

  Returns:
    A float array of shape (n,); 0 at x = (1, ..., 1) when noise-free.
  """
  points = np.asarray(points, dtype=float)
  summation = np.sum((points - 1.0)**2, axis=1)
  noise = np.random.normal(
      scale=abs(summation * noise_signal_ratio), size=len(points))
  return summation + noise
79

80

81
# Global Optimization Networks (GON):
82
def build_rtl_gon(lattice_dim):
  """Builds and compiles a GON (random tiny lattice ensemble) Keras model.

  Every input feature passes through a monotonically increasing
  piecewise-linear calibrator ("input_calibration", 10 keypoints on
  [-bound, bound], outputs in [-1, 1]). The calibrated features are shifted
  by +1 and fed, as `input_dim` random subsets of `lattice_dim` features, into
  size-3 lattices constrained to be jointly unimodal ("valley"). A dense layer
  with non-negative weights combines the lattice outputs.

  Args:
    lattice_dim: number of calibrated features per lattice. Must not exceed the
      `--input_dim` flag (subsets are drawn without replacement).

  Returns:
    A `tf.keras.Model` compiled with MSE loss and Adam, mapping inputs of shape
    (batch, input_dim) to (batch, 1).
  """
  # absl FlagHolders do not support arithmetic; read the parsed values.
  input_dim = _INPUT_DIM.value
  bound = _BOUND.value
  input_layer = tf.keras.layers.Input(shape=(input_dim,))
  calibrated_output = tfl.layers.PWLCalibration(
      input_keypoints=np.linspace(-bound, bound, 10),
      units=input_dim,
      output_min=-1.0,
      output_max=1.0,
      clamp_min=True,
      clamp_max=True,
      monotonicity="increasing",
      name="input_calibration",
  )(
      input_layer)
  lattice_inputs = []
  for _ in range(input_dim):
    indices = random.sample(range(input_dim), lattice_dim)
    # Shift [-1, 1] to [0, 2], the input domain of a size-3 lattice dimension.
    lattice_inputs.append(tf.gather(calibrated_output, indices, axis=1) + 1.0)
  lattice_input = tf.stack(lattice_inputs, axis=1)
  lattice_output_layer = tfl.layers.Lattice(
      lattice_sizes=[3] * lattice_dim,
      units=input_dim,
      joint_unimodalities=(list(range(lattice_dim)), "valley"),
      kernel_initializer="random_uniform",
  )(
      lattice_input)
  output_layer = tf.keras.layers.Dense(
      units=1, kernel_constraint=tf.keras.constraints.NonNeg())(
          lattice_output_layer)
  keras_model = tf.keras.models.Model(
      inputs=[input_layer], outputs=output_layer)
  keras_model.compile(
      loss="mean_squared_error", optimizer=tf.keras.optimizers.Adam())
  return keras_model
116

117

118
# Conditional Global Optimization Networks (CGON):
119
def build_rtl_cgon(lattice_dim, gon_dim):
  """Builds and compiles a conditional GON (CGON) Keras model.

  The first `gon_dim` input columns are the variables being optimized; the
  remaining `input_dim - gon_dim` columns are conditioning inputs. Each
  optimized variable i goes through a shared increasing calibrator
  ("gon_calibration") and is added to the mean of its own calibration of the
  conditioning inputs ("nongon_pwl_<i>"). The sums are squashed with
  tanh(.) + 1 into (0, 2) and fed, as `gon_dim` random subsets of
  `lattice_dim` features, into size-3 jointly unimodal ("valley") lattices,
  combined by a dense layer with non-negative weights.

  Args:
    lattice_dim: number of features per lattice; must not exceed `gon_dim`.
    gon_dim: number of leading input columns to optimize over.

  Returns:
    A `tf.keras.Model` compiled with MSE loss and Adam, mapping inputs of shape
    (batch, input_dim) to (batch, 1).
  """
  # absl FlagHolders do not support arithmetic; read the parsed values.
  input_dim = _INPUT_DIM.value
  bound = _BOUND.value
  input_layer = tf.keras.layers.Input(shape=(input_dim,))
  nongon_dim = input_dim - gon_dim
  gon_input_layer, nongon_input_layer = tf.split(
      input_layer, [gon_dim, nongon_dim], axis=1)
  gon_calibrated_output = tfl.layers.PWLCalibration(
      input_keypoints=np.linspace(-bound, bound, 10),
      units=gon_dim,
      monotonicity="increasing",
      output_min=-1.0,
      output_max=1.0,
      clamp_min=True,
      clamp_max=True,
      name="gon_calibration",
  )(
      gon_input_layer)
  gon_rtl_input_layer = []
  for i in range(gon_dim):
    # Per-variable conditioning signal: mean of a calibration of all
    # conditioning inputs, shape (batch, 1).
    nongon_calibrated_output = tf.reduce_mean(
        tfl.layers.PWLCalibration(
            input_keypoints=np.linspace(-bound, bound, 10),
            units=nongon_dim,
            output_min=-1.0,
            output_max=1.0,
            name="nongon_pwl_" + str(i))(nongon_input_layer),
        axis=1,
        keepdims=True)
    gon_rtl_input_layer.append(nongon_calibrated_output +
                               tf.gather(gon_calibrated_output, [i], axis=1))
  # tanh(.) + 1 maps the sums into (0, 2), the size-3 lattice input range.
  gon_rtl_input_layer = tf.tanh(tf.concat(gon_rtl_input_layer, axis=1)) + 1.0
  lattice_inputs = []
  for _ in range(gon_dim):
    gon_indices = random.sample(range(gon_dim), lattice_dim)
    gon_lattice_input = tf.gather(gon_rtl_input_layer, gon_indices, axis=1)
    lattice_inputs.append(gon_lattice_input)
  lattice_input = tf.stack(lattice_inputs, axis=1)
  lattice_output_layer = tfl.layers.Lattice(
      lattice_sizes=[3] * lattice_dim,
      units=gon_dim,
      joint_unimodalities=(list(range(lattice_dim)), "valley"),
      kernel_initializer="random_uniform",
  )(
      lattice_input)
  output_layer = tf.keras.layers.Dense(
      units=1, kernel_constraint=tf.keras.constraints.NonNeg())(
          lattice_output_layer)
  keras_model = tf.keras.models.Model(
      inputs=[input_layer], outputs=output_layer)
  keras_model.compile(
      loss="mean_squared_error", optimizer=tf.keras.optimizers.Adam())
  return keras_model
171

172

173
def get_optim_from_cal(calibrator_weights, target=0.0, bound=None):
  """Gets the argmin along one input by inverting its calibrator.

  The cumulative sum of `calibrator_weights` is treated as the calibrator
  output at keypoints evenly spaced on [-bound, bound] (so the first weight is
  the output at the first keypoint and each later weight is the increment over
  one segment). Returns the input x at which this piecewise-linear curve first
  reaches `target`, assuming the curve is non-decreasing.

  Args:
    calibrator_weights: 1-D sequence of weights for a single calibrator unit;
      its length is the number of keypoints.
    target: calibrator output value to invert.
    bound: half-width of the keypoint range. When None, the `--bound` flag
      value is used (flags must already be parsed in that case).

  Returns:
    The interpolated input value; -bound if `target` is at or below the first
    output, +bound if `target` is above every output.

  Raises:
    ValueError: if `calibrator_weights` is empty.
  """
  keypoints_y = np.cumsum(calibrator_weights)
  if keypoints_y.size == 0:
    raise ValueError("calibrator_weights must be non-empty.")
  if bound is None:
    bound = _BOUND.value
  # Derive the keypoint count from the weights instead of hard-coding it.
  keypoints_x = np.linspace(-bound, bound, len(keypoints_y))
  if target <= keypoints_y[0]:
    return keypoints_x[0]
  for i in range(len(keypoints_y) - 1):
    y_lo, y_hi = keypoints_y[i], keypoints_y[i + 1]
    # Inclusive upper end so a target hitting a keypoint exactly is found;
    # y_lo < target <= y_hi also guarantees y_hi > y_lo (no division by 0).
    if y_lo < target <= y_hi:
      w = (y_hi - target) / (y_hi - y_lo)
      return keypoints_x[i + 1] - w * (keypoints_x[i + 1] - keypoints_x[i])
  return keypoints_x[-1]
184

185

186
def main(_):
  """Runs the GON / CGON simulation and prints `[metric]` lines.

  Samples noisy training data from `data_generation_function`, divides the
  labels by the noise-free objective value at the corner (-bound, ..., -bound),
  trains the model(s) selected by `--cgon_or_gon`, reads an argmin estimate off
  the trained input calibrators and prints the objective value there (and, for
  GON, the squared distance to the true argmin).
  """
  data_generation_function = griewank  # rosenbrock / griewank / parabola
  # The global minimizer is at (1.0, 1.0, ..., 1.0).
  argmin_true = 1.0

  # absl FlagHolders support neither arithmetic nor comparisons, so read the
  # parsed values once.
  mode = _CGON_OR_GON.value
  num_examples = _NUM_EXAMPLES.value
  input_dim = _INPUT_DIM.value
  bound = _BOUND.value
  lattice_dim = _LATTICE_DIM.value
  batch_size = 100
  # Convert the step budget into epochs at `batch_size` examples per step;
  # always train for at least one epoch.
  epochs = max(1, int(_STEPS.value * batch_size / num_examples))

  # Generate Data
  training_inputs = sample_from_uniform_square(
      count=num_examples, dim=input_dim)
  labels = (
      data_generation_function(
          training_inputs, noise_signal_ratio=_NOISE_SIGNAL_RATIO.value) /
      data_generation_function(np.array([[-bound] * input_dim])))

  # Simulation for GON
  if mode != "CGON":
    # Global Optimization Networks
    keras_model_gon = build_rtl_gon(lattice_dim=lattice_dim)
    keras_model_gon.fit(
        training_inputs,
        labels,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
    )
    gon_cal_weights = keras_model_gon.get_layer(
        "input_calibration").get_weights()[0]
    result_gon = [
        get_optim_from_cal(gon_cal_weights[:, i]) for i in range(input_dim)
    ]
    # f(argmin_hat)
    print("[metric] GON_val=" +
          str(data_generation_function(np.array([result_gon]))[0]))
    # ||argmin_hat - argmin_true||^2
    print("[metric] GON_dist=" + str(
        np.sum([(a - argmin_true) * (a - argmin_true) for a in result_gon])))

  # Simulation for CGON
  if mode != "GON":
    # Input features for each example consist of the first optimization_dim
    # features for optimization, followed by conditional_dim conditional
    # inputs.
    optimization_dim = int(input_dim / 4 * 3)
    # Use the remainder so both parts always add up to input_dim and match the
    # split inside build_rtl_cgon, even when input_dim is not divisible by 4.
    conditional_dim = input_dim - optimization_dim
    # Conditional Global Optimization Networks
    keras_model_cgon = build_rtl_cgon(
        lattice_dim=lattice_dim, gon_dim=optimization_dim)
    keras_model_cgon.fit(
        training_inputs,
        labels,
        batch_size=batch_size,
        epochs=epochs,
        verbose=0,
    )
    # Optimize conditioned on all conditional inputs being zero.
    conditional_zeros = tf.zeros(shape=[1, conditional_dim], dtype=tf.float32)
    gon_cal_weights = keras_model_cgon.get_layer(
        "gon_calibration").get_weights()[0]
    result_cgon = []
    for i in range(optimization_dim):
      # Pick the gon calibrator output that cancels the conditional term, so
      # their sum is 0 and the lattice input is tanh(0) + 1 = 1.
      target = -tf.reduce_mean(
          keras_model_cgon.get_layer("nongon_pwl_" + str(i))(
              conditional_zeros)).numpy()
      result_cgon.append(get_optim_from_cal(gon_cal_weights[:, i], target))
    result_cgon = result_cgon + [0.0] * conditional_dim
    # f(argmin_hat)
    print("[metric] CGON_val=" +
          str(data_generation_function(np.array([result_cgon]))[0]))
253

254

255
if __name__ == "__main__":
  # absl's app.run parses the command-line flags, then calls main(argv).
  app.run(main)
257

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.