google-research

Форк
0
/
random_search_lib.py 
181 строка · 7.0 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# coding=utf-8
17
# Copyright 2023 The Google Research Authors.
18
#
19
# Licensed under the Apache License, Version 2.0 (the "License");
20
# you may not use this file except in compliance with the License.
21
# You may obtain a copy of the License at
22
#
23
#     http://www.apache.org/licenses/LICENSE-2.0
24
#
25
# Unless required by applicable law or agreed to in writing, software
26
# distributed under the License is distributed on an "AS IS" BASIS,
27
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28
# See the License for the specific language governing permissions and
29
# limitations under the License.
30

31
"""Running a random_search on the UGSL components."""
32

33
import copy
34
import os
35
import random
36
from absl import flags
37

38
from ml_collections import config_dict
39
from ml_collections import config_flags
40

41
# Default experiment config: the `config.py` file that sits next to this module.
_DEFAULT_CONFIG_PATH = os.path.join(os.path.dirname(__file__), "config.py")

# Base configuration file. Its `get_config()` result is the template that
# `sample_random_configs` deep-copies for every sampled run.
_CONFIG = config_flags.DEFINE_config_file(
    "config",
    _DEFAULT_CONFIG_PATH,
    (
        "Path to file containing configuration hyperparameters. "
        "File must define method `get_config()` to return an instance of "
        "`config_dict.ConfigDict`"
    ),
)

# The flags below marked `required=True` have no default (None); absl flag
# parsing rejects a command line that omits them.
_DATASET = flags.DEFINE_string(
    "dataset", None, "The name of the dataset.", required=True)
_NFEATS = flags.DEFINE_integer(
    "nfeats", None, "Number of node features.", required=True)
# Optional: number of random configurations to sample (default 5).
_NRUNS = flags.DEFINE_integer("nruns", 5, "Number of runs.")
# Root directory; each run i writes to `<experiment_dir>/<i>`.
_EXPERIMENT_DIR = flags.DEFINE_string(
    "experiment_dir", None, "The model directory.", required=True)
58

59

60
def sample_random_configs():
  """Builds one randomly sampled UGSL config per requested run.

  Reads the already-parsed flags `--config`, `--dataset`, `--nruns`,
  `--nfeats` and `--experiment_dir`. For each run index `i` in
  `range(--nruns)` it does the following:
  1. Deep-copies the base config.
  2. Sets `dataset.name` and samples the run's learning rate and weight decay.
  3. Points `run.model_dir` at `<experiment_dir>/<i>`.
  4. Re-samples every model component with the helper functions in this
     module.

  Returns:
    A list with one independent config per run (empty if `--nruns` <= 0).
  """
  base_config: config_dict.ConfigDict = _CONFIG.value
  dataset_name: str = _DATASET.value
  run_count: int = _NRUNS.value
  feature_count: int = _NFEATS.value
  output_root: str = _EXPERIMENT_DIR.value

  # Samplers that only need the config, applied in a fixed order. They all
  # draw from the global `random` state, so this order determines which
  # values a given seed produces.
  config_only_samplers = (
      set_sparsifier_config,
      set_processor_config,
      set_encoder_config,
      set_regularizer_config,
      get_unsupervised_config,
      get_positional_encoding_config,
  )

  sampled_configs = []
  for run_index in range(run_count):
    run_config = copy.deepcopy(base_config)
    run_config.dataset.name = dataset_name
    # NOTE(review): both values are drawn on a linear scale; log-uniform is
    # the more common choice for learning-rate / weight-decay search.
    run_config.run.learning_rate = random.uniform(1e-3, 1e-1)
    run_config.run.weight_decay = random.uniform(5e-4, 5e-2)
    run_config.run.model_dir = os.path.join(output_root, str(run_index))
    run_config = set_edge_scorer_config(run_config, feature_count)
    for sampler in config_only_samplers:
      run_config = sampler(run_config)
    sampled_configs.append(run_config)
  return sampled_configs
84

85

86
def set_edge_scorer_config(config, num_feats):
  """Samples random edge-scorer hyperparameters into `config`.

  Every field is drawn unconditionally, whichever scorer `name` is chosen.
  Which fields a particular scorer actually reads is decided by the model code,
  which is not part of this module.

  Args:
    config: Config whose `model.edge_scorer_cfg` section is updated in place.
    num_feats: Number of node features. It is offered as one of the two
      candidate hidden/output sizes, alongside 500.

  Returns:
    The same `config` object, for chaining.
  """
  es_cfg = config.model.edge_scorer_cfg
  es_cfg.name = random.choice(["mlp", "fp", "attentive"])
  es_cfg.nlayers = random.choice([1, 2])
  # Settings intended for the "mlp" scorer (still sampled for every scorer).
  es_cfg.activation = random.choice(["relu", "tanh"])
  es_cfg.dropout_rate = random.uniform(0.0, 75e-2)
  es_cfg.initialization = random.choice(["method1", "method2"])
  es_cfg.hidden_size = random.choice([500, num_feats])
  es_cfg.output_size = random.choice([500, num_feats])
  # Setting intended for the "attentive" scorer. Drawn exactly once: an extra
  # earlier draw would be overwritten here and only shift the RNG stream.
  es_cfg.nheads = random.choice([1, 2, 4])
  return config
102

103

104
def set_sparsifier_config(config):
  """Samples random sparsifier hyperparameters into `config`.

  Args:
    config: Config whose `model.sparsifier_cfg` section is updated in place.

  Returns:
    The same `config` object.
  """
  sparsifier = config.model.sparsifier_cfg
  sparsifier.name = random.choice(["knn", "dilated-knn"])
  # Presumably the neighbour count for (dilated) kNN — confirm in model code.
  sparsifier.k = random.choice([20, 25, 30])
  # Presumably the dilation factor for "dilated-knn" — confirm in model code.
  sparsifier.d = random.choice([2, 3])
  sparsifier.random_dilation = bool(random.getrandbits(1))
  return config
111

112

113
def set_processor_config(config):
  """Samples random graph-processor settings into `config`.

  Also samples the merger's dropout rate.

  Args:
    config: Config whose `model.processor_cfg` and `model.merger_cfg`
      sections are updated in place.

  Returns:
    The same `config` object.
  """
  model_cfg = config.model
  model_cfg.processor_cfg.name = random.choice(
      ["none", "symmetrize", "activation", "activation-symmetrize"])
  model_cfg.processor_cfg.activation = random.choice(["relu", "elu"])
  # NOTE(review): this writes merger_cfg, not processor_cfg. It is the only
  # merger setting sampled anywhere in this module; confirm it belongs here.
  model_cfg.merger_cfg.dropout_rate = random.uniform(0.0, 75e-2)
  return config
121

122

123
def set_encoder_config(config):
  """Samples random encoder hyperparameters into `config`.

  Args:
    config: Config whose `model.encoder_cfg` section is updated in place.

  Returns:
    The same `config` object.
  """
  enc = config.model.encoder_cfg
  enc.name = random.choice(["gcn", "gin"])
  enc.hidden_units = random.choice([16, 32, 64, 128])
  enc.activation = random.choice(["relu", "tanh"])
  enc.dropout_rate = random.uniform(0.0, 75e-2)
  return config
130

131

132
def set_regularizer_config(config):
  """Randomly enables and weights each graph regularizer in `config`.

  For every regularizer term `<t>` it sets `<t>_enable` (a coin flip) and
  `<t>_w` (uniform in [0, 20]). All enable flags are drawn first and the
  weights afterwards. Weights are sampled even for disabled terms.

  Args:
    config: Config whose `model.regularizer_cfg` section is updated in place.

  Returns:
    The same `config` object.
  """
  reg = config.model.regularizer_cfg
  terms = ("closeness", "smoothness", "sparseconnect", "logbarrier",
           "information")
  for term in terms:
    setattr(reg, term + "_enable", bool(random.getrandbits(1)))
  for term in terms:
    setattr(reg, term + "_w", random.uniform(0.0, 20.0))
  return config
147

148

149
def get_unsupervised_config(config):
  """Samples the unsupervised-loss settings into `config`.

  Two losses live under `model.unsupervised_cfg`: a contrastive loss and a
  denoising loss. Each one is switched on or off by its own coin flip. Its
  weight and other hyperparameters are sampled whether or not it ends up
  enabled.

  Args:
    config: Config whose `model.unsupervised_cfg` section is updated in place.

  Returns:
    The same `config` object.
  """
  unsup = config.model.unsupervised_cfg

  # Contrastive loss.
  contrastive = unsup.contrastive_cfg
  contrastive.enable = bool(random.getrandbits(1))
  contrastive.w = random.uniform(0.0, 20.0)
  contrastive.feature_mask_rate = random.uniform(1e-2, 75e-2)
  contrastive.temperature = random.uniform(0.1, 1.0)
  contrastive.tau = random.uniform(0.0, 0.2)

  # Denoising loss.
  denoising = unsup.denoising_cfg
  denoising.enable = bool(random.getrandbits(1))
  denoising.w = random.uniform(0.0, 20.0)
  denoising.dropout_rate = random.uniform(0.0, 0.75)
  denoising.hidden_units = random.choice([512, 1024])
  denoising.ones_ratio = random.choice([1, 5, 10])
  denoising.negative_ratio = random.choice([1, 5])

  return config
175

176

177
def get_positional_encoding_config(config):
  """Randomly toggles the dataset's positional-encoding options.

  Args:
    config: Config whose `dataset` section is updated in place.

  Returns:
    The same `config` object.
  """
  dataset_cfg = config.dataset
  # Two independent coin flips: WL position encoding first, then spectral.
  dataset_cfg.add_wl_position_encoding = bool(random.getrandbits(1))
  dataset_cfg.add_spectral_encoding = bool(random.getrandbits(1))
  return config
182

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.