# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Trains a model on a source dataset."""

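# Example invocation (the script filename is an assumption; adjust as needed):
#   python train_source_model.py --dataset cifar10 --gpu 0 --seed 100 \
#       --save-dir ./checkpoints/standard_supervised/
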
import argparse
import os

from active_selective_prediction.utils import data_util
from active_selective_prediction.utils import general_util
from active_selective_prediction.utils import model_util
import tensorflow as tf


def main():
  parser = argparse.ArgumentParser(
      description='Trains a model on a source dataset.'
  )
  parser.add_argument('--gpu', default='0', type=str, help='which GPU to use.')
  parser.add_argument(
      '--seed', default=100, type=int, help='set a fixed random seed.'
  )
  parser.add_argument(
      '--dataset',
      default='color_mnist',
      choices=[
          'cifar10',
          'domainnet',
          'color_mnist',
          'fmow',
          'amazon_review',
          'otto',
      ],
      type=str,
      help='which dataset to train a model on.',
  )
  parser.add_argument(
      '--save-dir',
      default='./checkpoints/standard_supervised/',
      type=str,
      help='the directory to save the trained model.',
  )
  args = parser.parse_args()
  state = dict(vars(args))
  print(state)
  seed = args.seed
  dataset = args.dataset
  save_dir = args.save_dir
  os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
  gpus = tf.config.experimental.list_physical_devices('GPU')
  if gpus:
    # Allocate GPU memory on demand; the guard also avoids an IndexError on
    # CPU-only hosts.
    tf.config.experimental.set_memory_growth(gpus[0], True)
  general_util.set_random_seed(seed)
  if not os.path.exists(save_dir):
    os.makedirs(save_dir)

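  # Per-dataset setup: data pipelines, training hyperparameters, and model
  # architecture. Each branch draws one batch from the training set to infer
  # the input shape used to build the model below.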
  if dataset == 'color_mnist':
    train_ds = data_util.get_color_mnist_dataset(
        split='train', batch_size=128, shuffle=True, drop_remainder=False
    )
    val_ds = data_util.get_color_mnist_dataset(
        split='test', batch_size=200, shuffle=False, drop_remainder=False
    )
    epochs = 20
    learning_rate = 1e-3
    num_classes = 10
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_simple_convnet(
        input_shape=input_shape, num_classes=num_classes
    )
  elif dataset == 'cifar10':
    train_ds = data_util.get_cifar10_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        data_augment=True,
    )
    val_ds = data_util.get_cifar10_dataset(
        split='test', batch_size=200, shuffle=False, drop_remainder=False
    )
    epochs = 200
    learning_rate = 1e-1
    num_classes = 10
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_cifar_resnet(
        input_shape=input_shape, num_classes=num_classes
    )
  elif dataset == 'domainnet':
    train_ds = data_util.get_domainnet_dataset(
        domain_name='real',
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        data_augment=True,
    )
    val_ds = data_util.get_domainnet_dataset(
        domain_name='real',
        split='test',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
    )
    epochs = 50
    learning_rate = 1e-4
    num_classes = 345
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_resnet50(
        input_shape=input_shape,
        num_classes=num_classes,
        weights='imagenet',
    )
  elif dataset == 'fmow':
    train_ds = data_util.get_fmow_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        data_augment=True,
        include_meta=False,
    )
    val_ds = data_util.get_fmow_dataset(
        split='id_val',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
        include_meta=False,
    )
    epochs = 50
    learning_rate = 1e-4
    num_classes = 62
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_densenet121(
        input_shape=input_shape,
        num_classes=num_classes,
        weights='imagenet',
    )
  elif dataset == 'amazon_review':
    train_ds = data_util.get_amazon_review_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        include_meta=False,
    )
    val_ds = data_util.get_amazon_review_dataset(
        split='id_val',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
        include_meta=False,
    )
    epochs = 200
    learning_rate = 1e-3
    num_classes = 5
    train_ds_iter = iter(train_ds)
    init_inputs, _ = next(train_ds_iter)
    # Exhaust the iterator after peeking at the first batch, presumably so
    # the input pipeline is traversed once in full (e.g., to populate a cache).
    for _ in train_ds_iter:
      pass
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_roberta_mlp(
        input_shape=input_shape,
        num_classes=num_classes,
    )
  elif dataset == 'otto':
    train_ds = data_util.get_otto_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
    )
    val_ds = data_util.get_otto_dataset(
        split='val',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
    )
    epochs = 200
    learning_rate = 1e-3
    num_classes = 9
    train_ds_iter = iter(train_ds)
    init_inputs, _ = next(train_ds_iter)
    # As above: drain the iterator so the pipeline is traversed once in full.
    for _ in train_ds_iter:
      pass
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_simple_mlp(
        input_shape=input_shape,
        num_classes=num_classes,
    )
  else:
    raise ValueError(f'Unsupported dataset {dataset}!')
  # Calling the model on a real batch builds its variables, so summary()
  # and the weight saving below work.
  model(init_inputs)
  model.summary()
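  # Training setup: Adam for most datasets; CIFAR-10 uses SGD with momentum
  # and a step learning-rate schedule.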
  if dataset == 'color_mnist':
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    model.fit(train_ds, epochs=epochs, validation_data=val_ds)
  elif dataset == 'cifar10':
    model.compile(
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=learning_rate, momentum=0.9
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    def scheduler_func(epoch, lr):
      # Step schedule: multiply the running learning rate by 0.1 at epochs
      # 80, 120, and 160, and by 0.5 at epoch 180.
      if epoch in (80, 120, 160):
        lr *= 0.1
      elif epoch == 180:
        lr *= 0.5
      return lr

    lr_scheduler = tf.keras.callbacks.LearningRateScheduler(scheduler_func)
    callbacks = [lr_scheduler]
    model.fit(
        train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds
    )
  elif dataset in ['domainnet', 'fmow', 'amazon_review', 'otto']:
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    model.fit(train_ds, epochs=epochs, validation_data=val_ds)
  model.save_weights(os.path.join(save_dir, dataset, 'checkpoint'))
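  # A minimal sketch of restoring these weights later (assumes the same
  # model-construction code and a known input shape, e.g. CIFAR-10):
  #   model = model_util.get_cifar_resnet(
  #       input_shape=(32, 32, 3), num_classes=10)
  #   model(tf.zeros((1, 32, 32, 3)))  # build variables before loading
  #   model.load_weights(os.path.join(save_dir, 'cifar10', 'checkpoint'))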


if __name__ == '__main__':
  main()