# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Experimental code for "Distribution Embedding Network for Meta Learning" on OpenML data.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os
import random
import time

from absl import app
from absl import flags
import numpy as np
import tensorflow as tf
import tensorflow_lattice as tfl

# Data preparation hparams
_DATA_DIRECTORY = flags.DEFINE_string(
    "openml_data_directory", None,
    "Directory that stores the training data. This directory should only "
    "include training csv files, and nothing else.")
_MAX_NUM_CLASSES = flags.DEFINE_integer("max_num_classes", 2,
                                        "max number of classes across tasks")
_OPENML_TEST_ID = flags.DEFINE_integer("openml_test_id", 0,
                                       "id of openml data to be used as test.")
_MISSING_VALUE = flags.DEFINE_float("missing_value", -1.0,
                                    "missing value in real data.")

# Data generation hparams
_NUM_FINE_TUNE = flags.DEFINE_integer("num_fine_tune", 50,
                                      "number of fine-tuning examples.")
_NUM_INPUTS = flags.DEFINE_integer(
    "num_inputs", 25, "max number of input dimensions across tasks.")
_PAD_VALUE = flags.DEFINE_integer("pad_value", -10, "value used for padding.")
_R = flags.DEFINE_integer("r", 2, "r value in embedding.")

# Training hparams
_PRETRAIN_BATCHES = flags.DEFINE_integer(
    "pretrain_batches", 10000, "number of steps to pretrain the model.")
_TUNE_EPOCHS = flags.DEFINE_integer("tune_epochs", 10,
                                    "number of fine-tuning epochs.")
_BATCH_SIZE = flags.DEFINE_integer(
    "batch_size", 64, "batch size for pretraining and finetuning.")

# DEN hparams
_NUM_CALIB_KEYS = flags.DEFINE_integer(
    "num_calib_keys", 10, "number of keypoints in the calibration layer.")
_HIDDEN_LAYER = flags.DEFINE_integer("hidden_layer", 2,
                                     "depth of the h, phi and psi functions.")
_DISTRIBUTION_REPRESENTATION_DIM = flags.DEFINE_integer(
    "distribution_representation_dim", 16, "width of the h function.")
_DEEPSETS_LAYER_UNITS = flags.DEFINE_integer("deepsets_layer_units", 10,
                                             "width of the phi function.")
_OUTPUT_LAYER_UNITS = flags.DEFINE_integer("output_layer_units", 8,
                                           "width of the psi function.")


#############################################################################
# Utils
#############################################################################
def load_openml_data():
  """Loads OpenML data from the data directory.

  Outputs a dictionary of the form {name: data}. Here data is a list of numpy
  arrays with the first column being labels and the rest being features.

  Returns:
    datasets: dictionary. Each element is a (name, val) pair where name is the
      dataset name and val is a list containing binary classification tasks
      within this dataset.
    files: list of files in the directory.
  """
  datasets = dict()
  files = os.listdir(_DATA_DIRECTORY.value)
  for file_name in files:
    with open(os.path.join(_DATA_DIRECTORY.value, file_name), "r") as ff:
      task = np.loadtxt(ff, delimiter=",", skiprows=1)
      np.random.shuffle(task)
      datasets[file_name] = [task]
  return datasets, files


def truncate_data(data, indices):
  """Truncates each task in data to the rows given by indices."""
  truncated_data = []
  for task in data:
    truncated_data.append(task[indices])
  return truncated_data


def compute_metrics(labels, predictions):
  """Computes metrics, returned as [loss, accuracy]."""
  loss = tf.keras.losses.SparseCategoricalCrossentropy()
  res = [loss(labels, predictions).numpy()]

  metrics = [tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")]
  for m in metrics:
    _ = m.update_state(labels, predictions)
    res.append(m.result().numpy())
  return res


def average_metrics(metrics):
  """Averages metrics across evaluation runs."""
  avg_metrics = dict()
  for name, metric in metrics.items():
    avg_metrics[name] = []
    for m in metric:
      avg_metrics[name].append({
          "mean": np.mean(np.array(m), axis=0),
          "std": np.std(np.array(m), axis=0)
      })
  return avg_metrics


def print_metrics(metrics, data_name):
  """Prints metrics."""
  metrics_name = ["loss", "accuracy", "auc"]
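  # Note: compute_metrics currently returns [loss, accuracy]; zip below
  # truncates to the shorter sequence, so "auc" is only printed if it is
  # added there as well.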
  for i, metric in enumerate(metrics):
    for mean, std, name in zip(metric["mean"], metric["std"], metrics_name):
      print(f"[metric] task{i}_{data_name}_{name}_mean={mean}")
      print(f"[metric] task{i}_{data_name}_{name}_std={std}")


def pad_features(features, size, axis=1, pad_value=None):
  """Pads features to the given size along axis."""
  if pad_value is None:  # Repeat columns
    num = features.shape[axis]
    repeat_indices = random.sample(range(num), size - num)
    repeat_features = tf.gather(features, repeat_indices, axis=axis)
    new_features = tf.concat([features, repeat_features], axis=axis)
  else:  # Add padding values
    paddings = [[0, 0] for _ in features.shape]
    paddings[axis] = [0, size - features.shape[axis]]
    new_features = tf.pad(
        features, tf.constant(paddings), constant_values=pad_value)
  return new_features


def get_pairwise_inputs(inputs):
  """Reshapes inputs to pairwise format."""
  # [BATCH_SIZE, NUM_INPUTS] --> [BATCH_SIZE, NUM_INPUTS**2, _R.value]
  num_features = inputs.shape[1]
  feature = []
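  # Re-seed from wall-clock time so each call draws a fresh random pairing
  # of feature columns.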
  np.random.seed(seed=np.mod(round(time.time() * 1000), 2**31))
  for _ in range(_R.value):
    random_indices = np.random.choice(range(num_features), num_features**2)
    feature.append(tf.gather(inputs, random_indices, axis=1))
  pairwise_inputs = tf.stack(feature, axis=-1)
  return pairwise_inputs


def copy_keras_model(model):
  """Copies a Keras model together with its weights."""
  new_model = tf.keras.models.clone_model(model)
  for layer, new_layer in zip(model.layers, new_model.layers):
    weights = layer.get_weights()
    new_layer.set_weights(weights)
  return new_model


def freeze_keras_model(model):
  """Freezes all layers of the model except input calibration."""
  model.trainable = True
  for layer in model.layers[::-1]:
    if "input_calibration" not in layer.name:
      layer.trainable = False  # freeze this layer


def compile_keras_model(model, init_lr=0.001):
  """Compiles the Keras model with cross-entropy loss and Adam."""
  model.compile(
      loss=tf.keras.losses.SparseCategoricalCrossentropy(),
      metrics=[tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")],
      optimizer=tf.keras.optimizers.Adam(learning_rate=init_lr))


def build_deepsets_joint_representation_model():
  """Builds a pairwise joint distribution representation model."""
  # We first create the embedding model
  test_input = tf.keras.layers.Input(shape=(_NUM_INPUTS.value,))
  train_input = tf.keras.layers.Input(shape=(_NUM_INPUTS.value,))
  train_label = tf.keras.layers.Input(shape=(1,))

  # Obtain a mask variable. Output dimension [1, _NUM_INPUTS.value]
  mask = tf.ones((1, _NUM_INPUTS.value))
  one_row = tf.reshape(tf.gather(train_input, [0], axis=0), [-1])
  mask = mask * tf.cast(tf.not_equal(one_row, _PAD_VALUE.value), tf.float32)

  # Calibrate each input feature to [0, 1] with a piecewise-linear
  # calibration layer shared between train and test inputs.
  calibration = tfl.layers.PWLCalibration(
      input_keypoints=np.linspace(0.0, 1.0, _NUM_CALIB_KEYS.value),
      units=_NUM_INPUTS.value,
      output_min=0.0,
      output_max=1.0,
      impute_missing=True,
      missing_input_value=_MISSING_VALUE.value,
      name="input_calibration")
  calibrated_train_input = calibration(train_input)
  calibrated_test_input = calibration(test_input)

  # Reshape the input to pair-wise format.
  # Output dimension [_BATCH_SIZE.value, _NUM_INPUTS.value**2, 2]
  pairwise_train_input = get_pairwise_inputs(calibrated_train_input)
  pairwise_test_input = get_pairwise_inputs(calibrated_test_input)

  # Obtain pairwise masks. Output dimension [_NUM_INPUTS.value**2,]
  pairwise_mask = get_pairwise_inputs(mask)
  pairwise_mask = tf.reshape(tf.reduce_prod(pairwise_mask, axis=-1), [-1])

  # Obtain pairwise labels.
  # Output dimension
  # [_BATCH_SIZE.value, _NUM_INPUTS.value**2, _MAX_NUM_CLASSES.value]
  one_hot_train_label = tf.one_hot(
      tf.cast(train_label, tf.int32), _MAX_NUM_CLASSES.value)
  pairwise_train_label = tf.tile(one_hot_train_label,
                                 tf.constant([1, _NUM_INPUTS.value**2, 1]))

  # Concatenate pairwise inputs and labels.
  # Output dimension
  # [_BATCH_SIZE.value, _NUM_INPUTS.value**2, _MAX_NUM_CLASSES.value + 2]
  pairwise_train_input = tf.concat([pairwise_train_input, pairwise_train_label],
                                   axis=-1)

  # Obtain distribution representation. Output dimension
  # [_BATCH_SIZE.value, _NUM_INPUTS.value**2,
  #  _DISTRIBUTION_REPRESENTATION_DIM.value]
  batch_embedding = tf.keras.layers.Dense(
      _DISTRIBUTION_REPRESENTATION_DIM.value, activation="relu")(
          pairwise_train_input)
  for _ in range(_HIDDEN_LAYER.value - 1):
    batch_embedding = tf.keras.layers.Dense(
        _DISTRIBUTION_REPRESENTATION_DIM.value, activation="relu")(
            batch_embedding)

  # Average embeddings over the batch. Output dimension
  # [_NUM_INPUTS.value**2, _DISTRIBUTION_REPRESENTATION_DIM.value].
  mean_distribution_embedding = tf.reduce_mean(batch_embedding, axis=0)

  outputs = []
  for pairwise_input in [pairwise_test_input, pairwise_train_input]:
    # [_NUM_INPUTS.value**2, _DISTRIBUTION_REPRESENTATION_DIM.value] ->
    # [_BATCH_SIZE.value, _NUM_INPUTS.value**2,
    #  _DISTRIBUTION_REPRESENTATION_DIM.value] via repetition.
    distribution_embedding = tf.tile(
        [mean_distribution_embedding],
        tf.stack([tf.shape(pairwise_input)[0],
                  tf.constant(1),
                  tf.constant(1)]))
    # Concatenate pairwise inputs and embeddings. Output shape
    # [_BATCH_SIZE.value, _NUM_INPUTS.value**2,
    #  2 + _DISTRIBUTION_REPRESENTATION_DIM.value]
    concat_input = tf.concat([pairwise_input, distribution_embedding], axis=-1)

    # Apply a common function to each pair. Output shape
    # [_BATCH_SIZE.value, _NUM_INPUTS.value**2, _DEEPSETS_LAYER_UNITS.value]
    pairwise_output = tf.keras.layers.Dense(
        _DEEPSETS_LAYER_UNITS.value, activation="relu")(
            concat_input)
    for _ in range(_HIDDEN_LAYER.value - 1):
      pairwise_output = tf.keras.layers.Dense(
          _DEEPSETS_LAYER_UNITS.value, activation="relu")(
              pairwise_output)

    # Average pair-wise outputs across valid pairs.
    # Output shape [_BATCH_SIZE.value, _DEEPSETS_LAYER_UNITS.value]
    average_outputs = tf.tensordot(pairwise_mask, pairwise_output, [[0], [1]])
    average_outputs = average_outputs / tf.reduce_sum(pairwise_mask)

    # Use several dense layers to get the final output
    final_output = tf.keras.layers.Dense(
        _OUTPUT_LAYER_UNITS.value, activation="relu")(
            average_outputs)
    for _ in range(_HIDDEN_LAYER.value - 1):
      final_output = tf.keras.layers.Dense(
          _OUTPUT_LAYER_UNITS.value, activation="relu")(
              final_output)
    outputs.append(final_output)
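
  # Soft nearest-neighbor classification: exponentiate the cosine similarities
  # between test and train embeddings, average them per class, and normalize
  # into class probabilities.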
  test_outputs = tf.math.l2_normalize(outputs[0], axis=1)
  train_outputs = tf.math.l2_normalize(outputs[1], axis=1)
  similarity_matrix = tf.exp(
      tf.matmul(test_outputs, tf.transpose(train_outputs)))

  similarity_list = []
  for i in range(_MAX_NUM_CLASSES.value):
    label_mask = tf.cast(tf.squeeze(tf.equal(train_label, i)), tf.float32)
    similarity_list.append(similarity_matrix * label_mask)

  similarity = [
      tf.reduce_mean(s, axis=1, keepdims=True) for s in similarity_list
  ]
  sum_similarity = tf.reduce_sum(
      tf.concat(similarity, axis=1), axis=1, keepdims=True)
  final_output = [s / sum_similarity for s in similarity]
  final_output = tf.concat(final_output, axis=1)

  keras_model = tf.keras.models.Model(
      inputs=[test_input, train_input, train_label], outputs=final_output)
  compile_keras_model(keras_model)
  return keras_model


#############################################################################
# Training and evaluation
#############################################################################
def prepare_training_examples(data):
  """Prepares training examples."""
  features = tf.convert_to_tensor(data[:, 1:], dtype=tf.float32)
  labels = tf.convert_to_tensor(data[:, 0], dtype=tf.float32)

  # Pad features
  features = pad_features(
      features, _NUM_INPUTS.value, pad_value=_PAD_VALUE.value)
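
  # Split the batch in half: one half serves as the labeled support set fed to
  # the model, the other half provides the query examples to predict.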
  half = features.shape[0] // 2
  inputs = [features[half:], features[:half], labels[:half]]
  true_labels = labels[half:]

  return inputs, true_labels


def prepare_testing_examples(data, test_data):
  """Prepares testing examples."""
  features = tf.convert_to_tensor(data[:, 1:], dtype=tf.float32)
  labels = tf.convert_to_tensor(data[:, 0], dtype=tf.float32)
  test_features = tf.convert_to_tensor(test_data[:, 1:], dtype=tf.float32)
  test_labels = tf.convert_to_tensor(test_data[:, 0], dtype=tf.float32)

  # Pad features
  features = pad_features(
      features, _NUM_INPUTS.value, pad_value=_PAD_VALUE.value)
  test_features = pad_features(
      test_features, _NUM_INPUTS.value, pad_value=_PAD_VALUE.value)

  test_inputs = [test_features, features, labels]
  return test_inputs, test_labels


def train_model(model,
                data,
                batch_size,
                epochs,
                init_lr=0.001,
                compile_model=False):
  """Trains the model."""
  inputs, labels = prepare_training_examples(data)

  if compile_model:
    compile_keras_model(model, init_lr)

  model.fit(inputs, labels, batch_size=batch_size, epochs=epochs, verbose=0)


def fine_tune_model(pretrain_models, fine_tune_task, tune_epochs, init_lr):
  """Fine-tunes the pretrained model."""
  # Make a copy of the pretrained model
  models = copy_keras_model(pretrain_models)
  freeze_keras_model(models)

  train_model(
      models,
      fine_tune_task,
      min(_BATCH_SIZE.value, _NUM_FINE_TUNE.value),
      tune_epochs,
      init_lr=init_lr,
      compile_model=True)
  return models


def evaluate_model(models, data, eval_data):
  """Evaluates the model."""
  inputs, labels = prepare_testing_examples(data, eval_data)

  predictions = models(inputs)
  metrics = compute_metrics(labels, predictions)
  return np.array(metrics)


def pretrain_model(models, datasets, batch_size):
  """Pretrains the model."""
  batch = 0  # Initialize batch counter
  while batch < _PRETRAIN_BATCHES.value:
    # Randomly select a task
    data_id = random.sample(range(len(datasets)), 1)[0]
    data = datasets[data_id]
    task_id = random.sample(range(len(data)), 1)[0]
    task = data[task_id]
    # Randomly select a batch
    batch_task = task[random.sample(range(task.shape[0]), batch_size)]
    # Train the model on this batch
    train_model(models, batch_task, batch_size, 1, init_lr=0.001)
    batch += 1


def run_simulation(pretrain_data, fine_tune_data, test_data, test_tasks):
  """Runs the simulation."""
  metrics = {"test": [list() for _ in range(len(test_data))]}

  # Pretraining
  pretrain_models = build_deepsets_joint_representation_model()
  pretrain_model(pretrain_models, pretrain_data, _BATCH_SIZE.value)

  for row in test_tasks:
    col_indices = [0] + [ind + 1 for ind in row[1:]]

    # Fine-tuning/direct training
    tune_task = fine_tune_data[row[0]][:, col_indices]
    models = fine_tune_model(pretrain_models, tune_task, _TUNE_EPOCHS.value,
                             0.001)
    # Evaluation
    metrics["test"][row[0]].append(
        evaluate_model(models, tune_task, test_data[row[0]][:, col_indices]))

  # Print metrics
  avg_metrics = average_metrics(metrics)
  print_metrics(avg_metrics["test"], "test")


def main(_):
  openml_datasets, openml_data_names = load_openml_data()

  # Prepare datasets
  target_name = openml_data_names[_OPENML_TEST_ID.value]
  target_data = openml_datasets[target_name]

  pretrain_data = [
      val for key, val in openml_datasets.items() if key != target_name
  ]
  fine_tune_data = truncate_data(target_data, range(_NUM_FINE_TUNE.value))

  test_range = range(_NUM_FINE_TUNE.value, target_data[0].shape[0])
  num_features = target_data[0].shape[1] - 1
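  # Each test task row is [task index, feature column indices]; here every
  # task uses all of its features.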
  test_tasks = [[task_id] + list(range(num_features))
                for task_id in range(len(target_data))]
  test_data = truncate_data(target_data, test_range)

  run_simulation(pretrain_data, fine_tune_data, test_data, test_tasks)


if __name__ == "__main__":
  app.run(main)