# Repository: google-research
# 181 lines · 7.0 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# coding=utf-8
17# Copyright 2023 The Google Research Authors.
18#
19# Licensed under the Apache License, Version 2.0 (the "License");
20# you may not use this file except in compliance with the License.
21# You may obtain a copy of the License at
22#
23# http://www.apache.org/licenses/LICENSE-2.0
24#
25# Unless required by applicable law or agreed to in writing, software
26# distributed under the License is distributed on an "AS IS" BASIS,
27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28# See the License for the specific language governing permissions and
29# limitations under the License.
30
"""Runs a random search over the UGSL components."""
32
33import copy
34import os
35import random
36from absl import flags
37
38from ml_collections import config_dict
39from ml_collections import config_flags
40
# Command-line flags that parameterize the random search. The flag holders
# are read through `.value` by `sample_random_configs`.

# Base configuration file; defaults to `config.py` next to this module.
_CONFIG = config_flags.DEFINE_config_file(
    "config",
    os.path.join(os.path.dirname(__file__), "config.py"),
    "Path to file containing configuration hyperparameters. "
    "File must define method `get_config()` to return an instance of "
    "`config_dict.ConfigDict`",
)
# Dataset name, copied into `config.dataset.name` of every sampled config.
_DATASET = flags.DEFINE_string(
    "dataset", None, "The name of the dataset.", required=True
)
# Node feature count, offered as a candidate edge-scorer layer size.
_NFEATS = flags.DEFINE_integer(
    "nfeats", None, "Number of node features.", required=True
)
# How many random configurations to sample.
_NRUNS = flags.DEFINE_integer("nruns", 5, "Number of runs.", required=False)
# Root directory; each run gets the sub-directory named after its run index.
_EXPERIMENT_DIR = flags.DEFINE_string(
    "experiment_dir", None, "The model directory.", required=True
)
58
59
def sample_random_configs():
  """Returns a list of random configurations based on flags in this file.

  Every configuration is a deep copy of the `config` flag value in which the
  dataset name is taken from the `dataset` flag, the model directory is
  `<experiment_dir>/<run index>`, and the learning rate, the weight decay and
  all component settings are re-sampled with the `random` module.

  Returns:
    A list holding one configuration per run, i.e. `nruns` entries.
  """
  default_config: config_dict.ConfigDict = _CONFIG.value
  dataset: str = _DATASET.value
  num_runs: int = _NRUNS.value
  num_feats: int = _NFEATS.value
  experiment_dir: str = _EXPERIMENT_DIR.value

  # Samplers that only take the config. They run in this order, after the
  # edge scorer sampler (which additionally needs the feature count), so the
  # sequence of random draws is fixed.
  component_samplers = (
      set_sparsifier_config,
      set_processor_config,
      set_encoder_config,
      set_regularizer_config,
      get_unsupervised_config,
      get_positional_encoding_config,
  )

  configs = []
  for run in range(num_runs):
    # Deep copy so that runs never share (and overwrite) nested settings.
    config = copy.deepcopy(default_config)
    config.dataset.name = dataset
    config.run.learning_rate = random.uniform(1e-3, 1e-1)
    config.run.weight_decay = random.uniform(5e-4, 5e-2)
    config.run.model_dir = os.path.join(experiment_dir, str(run))
    config = set_edge_scorer_config(config, num_feats)
    for sampler in component_samplers:
      config = sampler(config)
    configs.append(config)

  return configs
84
85
def set_edge_scorer_config(config, num_feats):
  """Sets the edge scorer config.

  Every field is sampled unconditionally, whichever scorer name is drawn.

  Args:
    config: Configuration object exposing `model.edge_scorer_cfg`. It is
      modified in place.
    num_feats: Number of node features; offered next to 500 as a candidate
      for the hidden and output sizes.

  Returns:
    The `config` object that was passed in.
  """
  scorer_cfg = config.model.edge_scorer_cfg
  scorer_cfg.name = random.choice(["mlp", "fp", "attentive"])
  scorer_cfg.nlayers = random.choice([1, 2])
  # Settings the original comments associate with the "mlp" scorer.
  scorer_cfg.activation = random.choice(["relu", "tanh"])
  scorer_cfg.dropout_rate = random.uniform(0.0, 75e-2)
  scorer_cfg.initialization = random.choice(["method1", "method2"])
  scorer_cfg.hidden_size = random.choice([500, num_feats])
  scorer_cfg.output_size = random.choice([500, num_feats])
  # Setting the original comments associate with the "attentive" scorer.
  # Drawn exactly once: an earlier duplicate assignment was always
  # overwritten by this one and only wasted a random draw.
  scorer_cfg.nheads = random.choice([1, 2, 4])
  return config
102
103
def set_sparsifier_config(config):
  """Sets the sparsifier config.

  Args:
    config: Configuration object exposing `model.sparsifier_cfg`. It is
      modified in place.

  Returns:
    The `config` object that was passed in.
  """
  sparsifier_cfg = config.model.sparsifier_cfg
  sparsifier_cfg.name = random.choice(["knn", "dilated-knn"])
  sparsifier_cfg.k = random.choice([20, 25, 30])
  # `d` and `random_dilation` are sampled for both sparsifier names.
  sparsifier_cfg.d = random.choice([2, 3])
  sparsifier_cfg.random_dilation = bool(random.getrandbits(1))
  return config
111
112
def set_processor_config(config):
  """Sets the processor config and the merger dropout rate.

  Args:
    config: Configuration object exposing `model.processor_cfg` and
      `model.merger_cfg`. It is modified in place.

  Returns:
    The `config` object that was passed in.
  """
  processor_cfg = config.model.processor_cfg
  processor_cfg.name = random.choice(
      ["none", "symmetrize", "activation", "activation-symmetrize"]
  )
  # The activation is sampled for every processor name.
  processor_cfg.activation = random.choice(["relu", "elu"])
  # NOTE(review): this writes to `merger_cfg`, not `processor_cfg`; no other
  # function in this file samples merger settings.
  config.model.merger_cfg.dropout_rate = random.uniform(0.0, 75e-2)
  return config
121
122
def set_encoder_config(config):
  """Sets the encoder config.

  Args:
    config: Configuration object exposing `model.encoder_cfg`. It is modified
      in place.

  Returns:
    The `config` object that was passed in.
  """
  encoder_cfg = config.model.encoder_cfg
  encoder_cfg.name = random.choice(["gcn", "gin"])
  encoder_cfg.hidden_units = random.choice([16, 32, 64, 128])
  encoder_cfg.activation = random.choice(["relu", "tanh"])
  encoder_cfg.dropout_rate = random.uniform(0.0, 75e-2)
  return config
130
131
def set_regularizer_config(config):
  """Sets the regularizer config.

  For each regularizer term an on/off switch (`<term>_enable`) and a weight
  (`<term>_w`) are sampled. The weight is sampled even when the switch comes
  out False.

  Args:
    config: Configuration object exposing `model.regularizer_cfg`. It is
      modified in place.

  Returns:
    The `config` object that was passed in.
  """
  regularizer_cfg = config.model.regularizer_cfg
  terms = (
      "closeness",
      "smoothness",
      "sparseconnect",
      "logbarrier",
      "information",
  )
  # All switches are drawn before any weight so the order of random draws is
  # fixed: five booleans followed by five uniform weights.
  for term in terms:
    setattr(regularizer_cfg, f"{term}_enable", bool(random.getrandbits(1)))
  for term in terms:
    setattr(regularizer_cfg, f"{term}_w", random.uniform(0.0, 20.0))
  return config
147
148
def get_unsupervised_config(config):
  """Sets the unsupervised loss config.

  Despite the `get_` prefix, this function samples new values into `config`
  in place, exactly like the `set_*_config` functions in this file.

  Args:
    config: Configuration object exposing
      `model.unsupervised_cfg.contrastive_cfg` and
      `model.unsupervised_cfg.denoising_cfg`. It is modified in place.

  Returns:
    The `config` object that was passed in.
  """
  unsupervised_cfg = config.model.unsupervised_cfg

  # Contrastive loss.
  contrastive_cfg = unsupervised_cfg.contrastive_cfg
  contrastive_cfg.enable = bool(random.getrandbits(1))
  contrastive_cfg.w = random.uniform(0.0, 20.0)
  contrastive_cfg.feature_mask_rate = random.uniform(1e-2, 75e-2)
  contrastive_cfg.temperature = random.uniform(0.1, 1.0)
  contrastive_cfg.tau = random.uniform(0.0, 0.2)

  # Denoising loss.
  denoising_cfg = unsupervised_cfg.denoising_cfg
  denoising_cfg.enable = bool(random.getrandbits(1))
  denoising_cfg.w = random.uniform(0.0, 20.0)
  denoising_cfg.dropout_rate = random.uniform(0.0, 0.75)
  denoising_cfg.hidden_units = random.choice([512, 1024])
  denoising_cfg.ones_ratio = random.choice([1, 5, 10])
  denoising_cfg.negative_ratio = random.choice([1, 5])

  return config
175
176
def get_positional_encoding_config(config):
  """Sets the positional encoding config.

  Despite the `get_` prefix, this function samples new values into `config`
  in place, exactly like the `set_*_config` functions in this file.

  Args:
    config: Configuration object exposing `dataset`. It is modified in place.

  Returns:
    The `config` object that was passed in.
  """
  dataset_cfg = config.dataset
  # Two independent coin flips, one per encoding.
  dataset_cfg.add_wl_position_encoding = bool(random.getrandbits(1))
  dataset_cfg.add_spectral_encoding = bool(random.getrandbits(1))
  return config
182