1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
"""Experimental code for "Distribution Embedding Network for Meta Learning" on OpenML data.
"""
18
19from __future__ import absolute_import
20from __future__ import division
21from __future__ import print_function
22
23import os
24import random
25import time
26
27from absl import app
28from absl import flags
29import numpy as np
30import tensorflow as tf
31import tensorflow_lattice as tfl
32
# Data preparation hparams
_DATA_DIRECTORY = flags.DEFINE_string(
    "openml_data_directory", None,
    # Fixed: the original help text concatenated "only"+"include" without a
    # space and misspelled "directory".
    "Directory that stores the training data. This directory should only "
    "include training csv files, and nothing else.")
_MAX_NUM_CLASSES = flags.DEFINE_integer(
    "max_num_classes", 2, "max number of classes across tasks")
_OPENML_TEST_ID = flags.DEFINE_integer(
    "openml_test_id", 0, "id of openml data to be used as test.")
_MISSING_VALUE = flags.DEFINE_float(
    "missing_value", -1.0, "missing value in real data.")

# Data generation hparams
_NUM_FINE_TUNE = flags.DEFINE_integer(
    "num_fine_tune", 50, "number of fine-tuning examples.")
_NUM_INPUTS = flags.DEFINE_integer(
    "num_inputs", 25, "max number of input dimensions across tasks.")
_PAD_VALUE = flags.DEFINE_integer("pad_value", -10, "value used for padding.")
_R = flags.DEFINE_integer("r", 2, "r value in embedding.")

# Training hparams
_PRETRAIN_BATCHES = flags.DEFINE_integer(
    "pretrain_batches", 10000, "number of steps to pretrain the model.")
_TUNE_EPOCHS = flags.DEFINE_integer(
    "tune_epochs", 10, "number of fine-tuning epochs.")
_BATCH_SIZE = flags.DEFINE_integer(
    "batch_size", 64, "batch size for pretraining and finetuning.")

# DEN hparams
_NUM_CALIB_KEYS = flags.DEFINE_integer(
    "num_calib_keys", 10, "number of keypoints in the calibration layer.")
_HIDDEN_LAYER = flags.DEFINE_integer(
    "hidden_layer", 2, "depth of the h, phi and psi functions.")
_DISTRIBUTION_REPRESENTATION_DIM = flags.DEFINE_integer(
    "distribution_representation_dim", 16, "width of the h function.")
_DEEPSETS_LAYER_UNITS = flags.DEFINE_integer(
    "deepsets_layer_units", 10, "width of the phi function.")
_OUTPUT_LAYER_UNITS = flags.DEFINE_integer(
    "output_layer_units", 8, "width of the psi function.")
73
74#############################################################################
75# Utils
76#############################################################################
def load_openml_data():
  """Loads OpenML csv data from the flag-specified directory.

  Output a dictionary of the form {name: data}. Here data is a list of numpy
  arrays with the first column being labels and the rest being features.

  Returns:
    datasets: dictionary. Each element is a (name, val) pair where name is the
      dataset file name and val is a list containing the classification tasks
      within this dataset (one shuffled numpy array per task).
    files: list of file names found in the directory.
  """
  datasets = dict()
  files = os.listdir(_DATA_DIRECTORY.value)
  for file_name in files:
    # Fixed: use os.path.join instead of raw string concatenation, which
    # produced a wrong path when the flag value lacked a trailing separator.
    with open(os.path.join(_DATA_DIRECTORY.value, file_name), "r") as ff:
      task = np.loadtxt(ff, delimiter=",", skiprows=1)
    # Shuffle rows in place so later train/test splits are random.
    np.random.shuffle(task)
    datasets[file_name] = [task]
  return datasets, files
97
98
def truncate_data(data, indices):
  """Returns each task in `data` restricted to the rows selected by `indices`."""
  return [task[indices] for task in data]
105
106
def compute_metrics(labels, predictions):
  """Returns [cross-entropy loss, accuracy] for the given predictions."""
  loss_fn = tf.keras.losses.SparseCategoricalCrossentropy()
  results = [loss_fn(labels, predictions).numpy()]

  accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
  accuracy.update_state(labels, predictions)
  results.append(accuracy.result().numpy())
  return results
117
118
def average_metrics(metrics):
  """Summarizes each task's repeated metric vectors by their mean and std.

  Args:
    metrics: dict mapping a name to a list of tasks, where each task is a
      list of per-run metric arrays.

  Returns:
    dict with the same keys; each value is a list of {"mean": ..., "std": ...}
    dicts computed across runs (axis 0).
  """
  averaged = dict()
  for name, per_task in metrics.items():
    summaries = []
    for runs in per_task:
      stacked = np.array(runs)
      summaries.append({
          "mean": np.mean(stacked, axis=0),
          "std": np.std(stacked, axis=0),
      })
    averaged[name] = summaries
  return averaged
130
131
def print_metrics(metrics, data_name):
  """Prints mean and std of each metric for every task."""
  metric_names = ("loss", "accuracy", "auc")
  for task_id, task_metrics in enumerate(metrics):
    rows = zip(task_metrics["mean"], task_metrics["std"], metric_names)
    for mean, std, name in rows:
      print(f"[metric] task{task_id}_{data_name}_{name}_mean={mean}")
      print(f"[metric] task{task_id}_{data_name}_{name}_std={std}")
139
140
def pad_features(features, size, axis=1, pad_value=None):
  """Pads `features` along `axis` up to `size` entries.

  With pad_value=None, padding duplicates randomly chosen existing slices;
  otherwise it appends constant `pad_value` entries.
  """
  current = features.shape[axis]
  if pad_value is None:
    # Repeat randomly sampled existing columns (without replacement).
    duplicate_idx = random.sample(range(current), size - current)
    duplicates = tf.gather(features, duplicate_idx, axis=axis)
    return tf.concat([features, duplicates], axis=axis)
  # Append constant padding after the existing entries on `axis` only.
  paddings = [[0, 0] for _ in features.shape]
  paddings[axis] = [0, size - current]
  return tf.pad(features, tf.constant(paddings), constant_values=pad_value)
154
155
def get_pairwise_inputs(inputs):
  """Reforms inputs into pairwise format.

  [BATCH_SIZE, NUM_INPUTS] --> [BATCH_SIZE, NUM_INPUTS**2, _R.value]
  """
  num_features = inputs.shape[1]
  # Re-seed the global numpy RNG from wall-clock milliseconds (mod int32).
  np.random.seed(seed=np.mod(round(time.time() * 1000), 2**31))
  gathered = []
  for _ in range(_R.value):
    chosen = np.random.choice(range(num_features), num_features**2)
    gathered.append(tf.gather(inputs, chosen, axis=1))
  return tf.stack(gathered, axis=-1)
167
168
def copy_keras_model(model):
  """Returns a structural clone of `model` with identical layer weights."""
  clone = tf.keras.models.clone_model(model)
  for src_layer, dst_layer in zip(model.layers, clone.layers):
    dst_layer.set_weights(src_layer.get_weights())
  return clone
176
177
def freeze_keras_model(model):
  """Freezes all layers except those named with "input_calibration"."""
  model.trainable = True
  for layer in reversed(model.layers):
    if "input_calibration" in layer.name:
      continue
    layer.trainable = False  # freeze this layer
185
def compile_keras_model(model, init_lr=0.001):
  """Compiles `model` with sparse cross-entropy, accuracy, and Adam."""
  loss = tf.keras.losses.SparseCategoricalCrossentropy()
  accuracy = tf.keras.metrics.SparseCategoricalAccuracy(name="accuracy")
  optimizer = tf.keras.optimizers.Adam(learning_rate=init_lr)
  model.compile(loss=loss, metrics=[accuracy], optimizer=optimizer)
192
193
def build_deepsets_joint_representation_model():
  """Builds a pairwise joint distribution representation (DEN) model.

  The model takes [test_input, train_input, train_label]. Calibrated feature
  pairs (plus one-hot labels on the train side) are embedded by dense layers
  (h), averaged over the batch into a distribution embedding, combined with
  each example's pairwise features through shared dense layers (phi), masked
  and averaged across valid pairs, and mapped to per-example vectors (psi).
  Test predictions come from exponentiated cosine similarity between test and
  train representations, normalized per class.

  Returns:
    A compiled tf.keras.Model (see compile_keras_model for loss/metrics).
  """
  # We first create the embedding model.
  test_input = tf.keras.layers.Input(shape=(_NUM_INPUTS.value,))
  train_input = tf.keras.layers.Input(shape=(_NUM_INPUTS.value,))
  train_label = tf.keras.layers.Input(shape=(1,))

  # Obtain a mask variable. Output dimension [1, _NUM_INPUTS.value].
  # A column counts as padded if its value in the first train row equals
  # _PAD_VALUE.value.
  mask = tf.ones((1, _NUM_INPUTS.value))
  one_row = tf.reshape(tf.gather(train_input, [0], axis=0), [-1])
  mask = mask * tf.cast(tf.not_equal(one_row, _PAD_VALUE.value), tf.float32)

  # Calibrate input if haven't done so. The same PWL calibration layer is
  # shared between train and test inputs.
  calibrated_train_input = train_input
  calibrated_test_input = test_input
  calibration = tfl.layers.PWLCalibration(
      input_keypoints=np.linspace(0.0, 1.0, _NUM_CALIB_KEYS.value),
      units=_NUM_INPUTS.value,
      output_min=0.0,
      output_max=1.0,
      impute_missing=True,
      missing_input_value=_MISSING_VALUE.value,
      name="input_calibration")
  calibrated_train_input = calibration(train_input)
  calibrated_test_input = calibration(test_input)

  # Reshape the input to pair-wise format.
  # Output dimension [_BATCH_SIZE.value, _NUM_INPUTS.value**2, 2]
  pairwise_train_input = get_pairwise_inputs(calibrated_train_input)
  pairwise_test_input = get_pairwise_inputs(calibrated_test_input)

  # Obtain pairwise masks. Output dimension [_NUM_INPUTS.value**2,].
  # A pair is valid only if both of its members are unpadded columns.
  pairwise_mask = get_pairwise_inputs(mask)
  pairwise_mask = tf.reshape(tf.reduce_prod(pairwise_mask, axis=-1), [-1])

  # Obtain pairwise labels.
  # Output dimension
  # [_BATCH_SIZE.value, _NUM_INPUTS.value**2, _MAX_NUM_CLASSES.value]
  one_hot_train_label = tf.one_hot(
      tf.cast(train_label, tf.int32), _MAX_NUM_CLASSES.value)
  pairwise_train_label = tf.tile(one_hot_train_label,
                                 tf.constant([1, _NUM_INPUTS.value**2, 1]))

  # Concatenate pairwise inputs and labels.
  # Output dimension
  # [_BATCH_SIZE.value, _NUM_INPUTS.value**2, _MAX_NUM_CLASSES.value + 2]
  pairwise_train_input = tf.concat([pairwise_train_input, pairwise_train_label],
                                   axis=-1)

  # Obtain distribution representation (the h function). Output dimension
  # [_BATCH_SIZE.value, _NUM_INPUTS.value**2,
  #  _DISTRIBUTION_REPRESENTATION_DIM.value]
  batch_embedding = tf.keras.layers.Dense(
      _DISTRIBUTION_REPRESENTATION_DIM.value, activation="relu")(
          pairwise_train_input)
  for _ in range(_HIDDEN_LAYER.value - 1):
    batch_embedding = tf.keras.layers.Dense(
        _DISTRIBUTION_REPRESENTATION_DIM.value, activation="relu")(
            batch_embedding)

  # Average embeddings over the batch. Output dimension
  # [_NUM_INPUTS.value**2, _DISTRIBUTION_REPRESENTATION_DIM.value].
  mean_distribution_embedding = tf.reduce_mean(batch_embedding, axis=0)

  outputs = []
  for pairwise_input in [pairwise_test_input, pairwise_train_input]:
    # [_NUM_INPUTS.value**2, _DISTRIBUTION_REPRESENTATION_DIM.value] ->
    # [_BATCH_SIZE.value, _NUM_INPUTS.value**2,
    #  _DISTRIBUTION_REPRESENTATION_DIM.value] via repetition.
    distribution_embedding = tf.tile(
        [mean_distribution_embedding],
        tf.stack([tf.shape(pairwise_input)[0],
                  tf.constant(1),
                  tf.constant(1)]))
    # Concatenate pairwise inputs and embeddings. Output shape
    # [_BATCH_SIZE.value, _NUM_INPUTS.value**2,
    #  2 + _DISTRIBUTION_REPRESENTATION_DIM.value]
    concat_input = tf.concat([pairwise_input, distribution_embedding], axis=-1)

    # Apply a common function (phi) to each pair. Output shape
    # [_BATCH_SIZE.value, _NUM_INPUTS.value**2, _DEEPSETS_LAYER_UNITS.value]
    pairwise_output = tf.keras.layers.Dense(
        _DEEPSETS_LAYER_UNITS.value, activation="relu")(
            concat_input)
    for _ in range(_HIDDEN_LAYER.value - 1):
      pairwise_output = tf.keras.layers.Dense(
          _DEEPSETS_LAYER_UNITS.value, activation="relu")(
              pairwise_output)

    # Average pair-wise outputs across valid pairs (masked mean).
    # Output shape [_BATCH_SIZE.value, _DEEPSETS_LAYER_UNITS.value]
    average_outputs = tf.tensordot(pairwise_mask, pairwise_output, [[0], [1]])
    average_outputs = average_outputs / tf.reduce_sum(pairwise_mask)

    # Use several dense layers (psi) to get the final output.
    final_output = tf.keras.layers.Dense(
        _OUTPUT_LAYER_UNITS.value, activation="relu")(
            average_outputs)
    for i in range(_HIDDEN_LAYER.value - 1):
      final_output = tf.keras.layers.Dense(
          _OUTPUT_LAYER_UNITS.value, activation="relu")(
              final_output)
    outputs.append(final_output)

  # Exponentiated cosine similarity between every test and train example.
  test_outputs = tf.math.l2_normalize(outputs[0], axis=1)
  train_outputs = tf.math.l2_normalize(outputs[1], axis=1)
  similarity_matrix = tf.exp(
      tf.matmul(test_outputs, tf.transpose(train_outputs)))

  # For each class, zero out similarities to train examples of other classes.
  similarity_list = []
  for i in range(_MAX_NUM_CLASSES.value):
    mask = tf.cast(tf.squeeze(tf.equal(train_label, i)), tf.float32)
    similarity_list.append(similarity_matrix * mask)

  # Per-class mean similarity per test example and its sum over classes.
  similarity = [
      tf.reduce_mean(s, axis=1, keepdims=True) for s in similarity_list
  ]
  sum_similarity = tf.reduce_sum(
      tf.concat(similarity, axis=1), axis=1, keepdims=True)
  # NOTE(review): this comprehension iterates similarity_list (the full masked
  # [test, train] matrices) while its loop variable shadows the `similarity`
  # list of per-class means above, so the concatenated output is
  # num-classes * num-train wide rather than num-classes wide. Verify whether
  # `for s in similarity` was intended.
  final_output = [similarity / sum_similarity for similarity in similarity_list]
  final_output = tf.concat(final_output, axis=1)

  keras_model = tf.keras.models.Model(
      inputs=[test_input, train_input, train_label], outputs=final_output)
  compile_keras_model(keras_model)
  return keras_model
320
321
322#############################################################################
323# Training and evaluation
324#############################################################################
def prepare_training_examples(data):
  """Splits `data` into model inputs and target labels.

  The first column of `data` holds labels; the rest are features. Features
  are padded to _NUM_INPUTS.value columns, then the rows are split in half:
  the first half serves as the support (train) set, the second as queries.
  """
  labels = tf.convert_to_tensor(data[:, 0], dtype=tf.float32)
  features = tf.convert_to_tensor(data[:, 1:], dtype=tf.float32)

  # Pad features to the common width.
  features = pad_features(
      features, _NUM_INPUTS.value, pad_value=_PAD_VALUE.value)

  half = features.shape[0] // 2
  inputs = [features[half:], features[:half], labels[:half]]
  return inputs, labels[half:]
339
340
def prepare_testing_examples(data, test_data):
  """Builds model inputs for evaluation.

  `data` supplies the support (train) set and `test_data` the queries; the
  first column of each holds labels. Features are padded to a common width.
  """
  labels = tf.convert_to_tensor(data[:, 0], dtype=tf.float32)
  features = tf.convert_to_tensor(data[:, 1:], dtype=tf.float32)
  test_labels = tf.convert_to_tensor(test_data[:, 0], dtype=tf.float32)
  test_features = tf.convert_to_tensor(test_data[:, 1:], dtype=tf.float32)

  # Pad both feature sets to the common width.
  features = pad_features(
      features, _NUM_INPUTS.value, pad_value=_PAD_VALUE.value)
  test_features = pad_features(
      test_features, _NUM_INPUTS.value, pad_value=_PAD_VALUE.value)

  return [test_features, features, labels], test_labels
356
357
def train_model(models,
                data,
                batch_size,
                epochs,
                init_lr=0.001,
                compile_model=False):
  """Fits the model on one task's data, optionally (re)compiling first."""
  inputs, labels = prepare_training_examples(data)

  if compile_model:
    compile_keras_model(models, init_lr)

  model = models
  model.fit(inputs, labels, batch_size=batch_size, epochs=epochs, verbose=0)
372
373
def fine_tune_model(pretrain_models, fine_tune_task, tune_epochs, init_lr):
  """Fine-tunes a frozen copy of the pretrained model on one task."""
  # Work on a copy so the pretrained weights stay untouched.
  tuned = copy_keras_model(pretrain_models)
  freeze_keras_model(tuned)

  fine_tune_batch = min(_BATCH_SIZE.value, _NUM_FINE_TUNE.value)
  train_model(
      tuned,
      fine_tune_task,
      fine_tune_batch,
      tune_epochs,
      init_lr=init_lr,
      compile_model=True)
  return tuned
388
389
def evaluate_model(models, data, eval_data):
  """Evaluates the model on `eval_data`, with `data` as the support set."""
  inputs, labels = prepare_testing_examples(data, eval_data)
  predictions = models(inputs)
  return np.array(compute_metrics(labels, predictions))
397
398
def pretrain_model(models, datasets, batch_size):
  """Pretrains the model on randomly sampled batches from random tasks."""
  for _ in range(_PRETRAIN_BATCHES.value):
    # Randomly pick a dataset, then a task within it.
    dataset = datasets[random.sample(range(len(datasets)), 1)[0]]
    task = dataset[random.sample(range(len(dataset)), 1)[0]]
    # Randomly pick a batch of rows and train on it for one epoch.
    row_indices = random.sample(range(task.shape[0]), batch_size)
    train_model(models, task[row_indices], batch_size, 1, init_lr=0.001)
413
414
def run_simulation(pretrain_data, fine_tune_data, test_data, test_tasks):
  """Pretrains, fine-tunes per test task, evaluates, and prints metrics."""
  metrics = {"test": [[] for _ in range(len(test_data))]}

  # Pretraining on all non-target datasets.
  pretrain_models = build_deepsets_joint_representation_model()
  pretrain_model(pretrain_models, pretrain_data, _BATCH_SIZE.value)

  for row in test_tasks:
    task_id = row[0]
    # Column 0 is the label; feature indices are shifted by one.
    col_indices = [0] + [feature + 1 for feature in row[1:]]

    # Fine-tuning/direct training on the held-out task.
    tune_task = fine_tune_data[task_id][:, col_indices]
    models = fine_tune_model(pretrain_models, tune_task, _TUNE_EPOCHS.value,
                             0.001)
    # Evaluation on the task's test split.
    eval_task = test_data[task_id][:, col_indices]
    metrics["test"][task_id].append(evaluate_model(models, tune_task,
                                                   eval_task))

  # Print metrics.
  avg_metrics = average_metrics(metrics)
  print_metrics(avg_metrics["test"], "test")
437
438
def main(_):
  """Loads data, holds out the target dataset, and runs the simulation."""
  openml_datasets, openml_data_names = load_openml_data()

  # Prepare datasets: the dataset at _OPENML_TEST_ID is held out for
  # fine-tuning/testing; all others are used for pretraining.
  target_name = openml_data_names[_OPENML_TEST_ID.value]
  target_data = openml_datasets[target_name]

  # Fixed: the original used `key not in target_names`, a substring test on a
  # single file name, which could wrongly exclude an unrelated dataset whose
  # name happens to be a substring of the target's. Use exact inequality.
  pretrain_data = [
      val for key, val in openml_datasets.items() if key != target_name
  ]
  fine_tune_data = truncate_data(target_data, range(_NUM_FINE_TUNE.value))

  # Everything after the fine-tuning rows becomes test data.
  test_range = range(_NUM_FINE_TUNE.value, target_data[0].shape[0])
  num_features = target_data[0].shape[1] - 1
  test_tasks = [[task_id] + list(range(num_features))
                for task_id in range(len(target_data))]
  test_data = truncate_data(target_data, test_range)

  run_simulation(pretrain_data, fine_tune_data, test_data, test_tasks)
458
459
460if __name__ == "__main__":
461app.run(main)
462