google-research
258 lines · 7.5 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Trains a model on a source dataset."""
17
18import argparse
19import os
20
21from active_selective_prediction.utils import data_util
22from active_selective_prediction.utils import general_util
23from active_selective_prediction.utils import model_util
24import tensorflow as tf
25
26
def _build_source_task(dataset):
  """Builds the datasets, model, and hyper-parameters for a source dataset.

  Args:
    dataset: Name of the source dataset. One of 'color_mnist', 'cifar10',
      'domainnet', 'fmow', 'amazon_review', or 'otto'.

  Returns:
    A tuple `(train_ds, val_ds, model, init_inputs, epochs, learning_rate)`
    where `init_inputs` is one batch of inputs (used to build model weights
    via a forward pass) and `model` is an uncompiled Keras model.

  Raises:
    ValueError: If `dataset` is not one of the supported names.
  """
  if dataset == 'color_mnist':
    train_ds = data_util.get_color_mnist_dataset(
        split='train', batch_size=128, shuffle=True, drop_remainder=False
    )
    val_ds = data_util.get_color_mnist_dataset(
        split='test', batch_size=200, shuffle=False, drop_remainder=False
    )
    epochs = 20
    learning_rate = 1e-3
    num_classes = 10
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_simple_convnet(
        input_shape=input_shape, num_classes=num_classes
    )
  elif dataset == 'cifar10':
    train_ds = data_util.get_cifar10_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        data_augment=True,
    )
    val_ds = data_util.get_cifar10_dataset(
        split='test', batch_size=200, shuffle=False, drop_remainder=False
    )
    epochs = 200
    learning_rate = 1e-1
    num_classes = 10
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_cifar_resnet(
        input_shape=input_shape, num_classes=num_classes
    )
  elif dataset == 'domainnet':
    train_ds = data_util.get_domainnet_dataset(
        domain_name='real',
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        data_augment=True,
    )
    val_ds = data_util.get_domainnet_dataset(
        domain_name='real',
        split='test',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
    )
    epochs = 50
    learning_rate = 1e-4
    num_classes = 345
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_resnet50(
        input_shape=input_shape,
        num_classes=num_classes,
        weights='imagenet',
    )
  elif dataset == 'fmow':
    train_ds = data_util.get_fmow_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        data_augment=True,
        include_meta=False,
    )
    val_ds = data_util.get_fmow_dataset(
        split='id_val',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
        include_meta=False,
    )
    epochs = 50
    learning_rate = 1e-4
    num_classes = 62
    init_inputs, _ = next(iter(train_ds))
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_densenet121(
        input_shape=input_shape,
        num_classes=num_classes,
        weights='imagenet',
    )
  elif dataset == 'amazon_review':
    train_ds = data_util.get_amazon_review_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
        include_meta=False,
    )
    val_ds = data_util.get_amazon_review_dataset(
        split='id_val',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
        include_meta=False,
    )
    epochs = 200
    learning_rate = 1e-3
    num_classes = 5
    train_ds_iter = iter(train_ds)
    init_inputs, _ = next(train_ds_iter)
    # Drains the rest of the iterator after taking the first batch —
    # presumably required by the stateful underlying dataset so that
    # `model.fit` starts a clean epoch; TODO(review): confirm necessity.
    for _ in train_ds_iter:
      pass
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_roberta_mlp(
        input_shape=input_shape,
        num_classes=num_classes,
    )
  elif dataset == 'otto':
    train_ds = data_util.get_otto_dataset(
        split='train',
        batch_size=128,
        shuffle=True,
        drop_remainder=False,
    )
    val_ds = data_util.get_otto_dataset(
        split='val',
        batch_size=128,
        shuffle=False,
        drop_remainder=False,
    )
    epochs = 200
    learning_rate = 1e-3
    num_classes = 9
    train_ds_iter = iter(train_ds)
    init_inputs, _ = next(train_ds_iter)
    # Same iterator-draining pattern as 'amazon_review' above.
    for _ in train_ds_iter:
      pass
    input_shape = tuple(init_inputs.shape[1:])
    model = model_util.get_simple_mlp(
        input_shape=input_shape,
        num_classes=num_classes,
    )
  else:
    raise ValueError(f'Unsupported dataset {dataset}!')
  return train_ds, val_ds, model, init_inputs, epochs, learning_rate


def _compile_and_fit(model, dataset, train_ds, val_ds, epochs, learning_rate):
  """Compiles `model` with dataset-specific settings and trains it.

  CIFAR-10 uses SGD with momentum plus a step learning-rate schedule; every
  other dataset uses Adam with a constant learning rate (this matches the
  original per-dataset branches, which were identical for all non-CIFAR
  datasets).

  Args:
    model: Built (weights-created) Keras model to train in place.
    dataset: Dataset name; selects optimizer and callbacks.
    train_ds: Training tf.data dataset of (inputs, sparse labels).
    val_ds: Validation tf.data dataset.
    epochs: Number of training epochs.
    learning_rate: Initial learning rate.
  """
  if dataset == 'cifar10':
    model.compile(
        optimizer=tf.keras.optimizers.SGD(
            learning_rate=learning_rate, momentum=0.9
        ),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )

    def scheduler_func(epoch, lr):
      # Step decay: x0.1 at epochs 80, 120, 160 and x0.5 at epoch 180.
      if epoch == 80:
        lr *= 0.1
      elif epoch == 120:
        lr *= 0.1
      elif epoch == 160:
        lr *= 0.1
      elif epoch == 180:
        lr *= 0.5
      return lr

    callbacks = [tf.keras.callbacks.LearningRateScheduler(scheduler_func)]
    model.fit(
        train_ds, epochs=epochs, callbacks=callbacks, validation_data=val_ds
    )
  else:
    model.compile(
        optimizer=tf.keras.optimizers.Adam(learning_rate=learning_rate),
        loss=tf.keras.losses.SparseCategoricalCrossentropy(),
        metrics=[tf.keras.metrics.SparseCategoricalAccuracy()],
    )
    model.fit(train_ds, epochs=epochs, validation_data=val_ds)


def main():
  """Parses flags, trains a model on a source dataset, saves its weights."""
  parser = argparse.ArgumentParser(
      description='pipeline for detecting dataset shift'
  )
  parser.add_argument('--gpu', default='0', type=str, help='which gpu to use.')
  parser.add_argument(
      '--seed', default=100, type=int, help='set a fixed random seed.'
  )
  parser.add_argument(
      '--dataset',
      default='color_mnist',
      choices=[
          'cifar10',
          'domainnet',
          'color_mnist',
          'fmow',
          'amazon_review',
          'otto',
      ],
      type=str,
      help='which dataset to train a model',
  )
  parser.add_argument(
      '--save-dir',
      default='./checkpoints/standard_supervised/',
      type=str,
      help='the dir to save trained model',
  )
  args = parser.parse_args()
  # Log the full flag configuration for reproducibility.
  print(vars(args))
  seed = args.seed
  dataset = args.dataset
  save_dir = args.save_dir
  os.environ['CUDA_VISIBLE_DEVICES'] = args.gpu
  gpus = tf.config.experimental.list_physical_devices('GPU')
  # list_physical_devices returns [] on CPU-only hosts; indexing gpus[0]
  # unconditionally would raise IndexError.
  if gpus:
    tf.config.experimental.set_memory_growth(gpus[0], True)
  general_util.set_random_seed(seed)
  # exist_ok avoids the racy exists()-then-makedirs() pattern.
  os.makedirs(save_dir, exist_ok=True)

  train_ds, val_ds, model, init_inputs, epochs, learning_rate = (
      _build_source_task(dataset)
  )
  # Builds the model's weights by running one forward pass on a real batch.
  model(init_inputs)
  model.summary()
  _compile_and_fit(model, dataset, train_ds, val_ds, epochs, learning_rate)
  model.save_weights(os.path.join(save_dir, f'{dataset}', 'checkpoint'))
255
256
# Script entry point: run training only when executed directly,
# not when this module is imported.
if __name__ == '__main__':
  main()
259