# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16
"""Train and evaluate models for time series forecasting."""

import argparse
import glob
import logging
import os
import time

from data_loader import TSFDataLoader
import models
import numpy as np
import pandas as pd
import tensorflow as tf

os.environ['TF_CPP_MIN_LOG_LEVEL'] = '3'  # FATAL
31
logging.getLogger('tensorflow').setLevel(logging.FATAL)


def parse_args():
35
  """Parse the arguments for experiment configuration."""
36

37
  parser = argparse.ArgumentParser(
38
      description='TSMixer for Time Series Forecasting'
39
  )
40

41
  # basic config
42
  parser.add_argument('--seed', type=int, default=0, help='random seed')
43
  parser.add_argument(
44
      '--model',
45
      type=str,
46
      default='tsmixer',
47
      help='model name, options: [tsmixer, tsmixer_rev_in]',
48
  )
49

50
  # data loader
51
  parser.add_argument(
52
      '--data',
53
      type=str,
54
      default='weather',
55
      choices=[
56
          'electricity',
57
          'exchange_rate',
58
          'national_illness',
59
          'traffic',
60
          'weather',
61
          'ETTm1',
62
          'ETTm2',
63
          'ETTh1',
64
          'ETTh2',
65
      ],
66
      help='data name',
67
  )
68
  parser.add_argument(
69
      '--feature_type',
70
      type=str,
71
      default='M',
72
      choices=['S', 'M', 'MS'],
73
      help=(
74
          'forecasting task, options:[M, S, MS]; M:multivariate predict'
75
          ' multivariate, S:univariate predict univariate, MS:multivariate'
76
          ' predict univariate'
77
      ),
78
  )
79
  parser.add_argument(
80
      '--target', type=str, default='OT', help='target feature in S or MS task'
81
  )
82
  parser.add_argument(
83
      '--checkpoint_dir',
84
      type=str,
85
      default='./checkpoints/',
86
      help='location of model checkpoints',
87
  )
88
  parser.add_argument(
89
      '--delete_checkpoint',
90
      action='store_true',
91
      help='delete checkpoints after the experiment',
92
  )
93

94
  # forecasting task
95
  parser.add_argument(
96
      '--seq_len', type=int, default=336, help='input sequence length'
97
  )
98
  parser.add_argument(
99
      '--pred_len', type=int, default=96, help='prediction sequence length'
100
  )
101

102
  # model hyperparameter
103
  parser.add_argument(
104
      '--n_block',
105
      type=int,
106
      default=2,
107
      help='number of block for deep architecture',
108
  )
109
  parser.add_argument(
110
      '--ff_dim',
111
      type=int,
112
      default=2048,
113
      help='fully-connected feature dimension',
114
  )
115
  parser.add_argument(
116
      '--dropout', type=float, default=0.05, help='dropout rate'
117
  )
118
  parser.add_argument(
119
      '--norm_type',
120
      type=str,
121
      default='B',
122
      choices=['L', 'B'],
123
      help='LayerNorm or BatchNorm',
124
  )
125
  parser.add_argument(
126
      '--activation',
127
      type=str,
128
      default='relu',
129
      choices=['relu', 'gelu'],
130
      help='Activation function',
131
  )
132
  parser.add_argument(
133
      '--kernel_size', type=int, default=4, help='kernel size for CNN'
134
  )
135
  parser.add_argument(
136
      '--temporal_dim', type=int, default=16, help='temporal feature dimension'
137
  )
138
  parser.add_argument(
139
      '--hidden_dim', type=int, default=64, help='hidden feature dimension'
140
  )
141

142
  # optimization
143
  parser.add_argument(
144
      '--num_workers', type=int, default=10, help='data loader num workers'
145
  )
146
  parser.add_argument(
147
      '--train_epochs', type=int, default=100, help='train epochs'
148
  )
149
  parser.add_argument(
150
      '--batch_size', type=int, default=32, help='batch size of input data'
151
  )
152
  parser.add_argument(
153
      '--learning_rate',
154
      type=float,
155
      default=0.0001,
156
      help='optimizer learning rate',
157
  )
158
  parser.add_argument(
159
      '--patience', type=int, default=5, help='number of epochs to early stop'
160
  )
161

162
  # save results
163
  parser.add_argument(
164
      '--result_path', default='result.csv', help='path to save result'
165
  )
166

167
  args = parser.parse_args()
168

169
  tf.keras.utils.set_random_seed(args.seed)
170

171
  return args


def main():
175
  args = parse_args()
176
  if 'tsmixer' in args.model:
177
    exp_id = f'{args.data}_{args.feature_type}_{args.model}_sl{args.seq_len}_pl{args.pred_len}_lr{args.learning_rate}_nt{args.norm_type}_{args.activation}_nb{args.n_block}_dp{args.dropout}_fd{args.ff_dim}'
178
  elif args.model == 'full_linear':
179
    exp_id = f'{args.data}_{args.feature_type}_{args.model}_sl{args.seq_len}_pl{args.pred_len}_lr{args.learning_rate}'
180
  elif args.model == 'cnn':
181
    exp_id = f'{args.data}_{args.feature_type}_{args.model}_sl{args.seq_len}_pl{args.pred_len}_lr{args.learning_rate}_ks{args.kernel_size}'
182
  else:
183
    raise ValueError(f'Unknown model type: {args.model}')
184

185
  # load datasets
186
  data_loader = TSFDataLoader(
187
      args.data,
188
      args.batch_size,
189
      args.seq_len,
190
      args.pred_len,
191
      args.feature_type,
192
      args.target,
193
  )
194
  train_data = data_loader.get_train()
195
  val_data = data_loader.get_val()
196
  test_data = data_loader.get_test()
197

198
  # train model
199
  if 'tsmixer' in args.model:
200
    build_model = getattr(models, args.model).build_model
201
    model = build_model(
202
        input_shape=(args.seq_len, data_loader.n_feature),
203
        pred_len=args.pred_len,
204
        norm_type=args.norm_type,
205
        activation=args.activation,
206
        dropout=args.dropout,
207
        n_block=args.n_block,
208
        ff_dim=args.ff_dim,
209
        target_slice=data_loader.target_slice,
210
    )
211
  elif args.model == 'full_linear':
212
    model = models.full_linear.Model(
213
        n_channel=data_loader.n_feature,
214
        pred_len=args.pred_len,
215
    )
216
  elif args.model == 'cnn':
217
    model = models.cnn.Model(
218
        n_channel=data_loader.n_feature,
219
        pred_len=args.pred_len,
220
        kernel_size=args.kernel_size,
221
    )
222
  else:
223
    raise ValueError(f'Model not supported: {args.model}')
224

225
  optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
226
  model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
227
  checkpoint_path = os.path.join(args.checkpoint_dir, f'{exp_id}_best')
228
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
229
      filepath=checkpoint_path,
230
      verbose=1,
231
      save_best_only=True,
232
      save_weights_only=True,
233
  )
234
  early_stop_callback = tf.keras.callbacks.EarlyStopping(
235
      monitor='val_loss', patience=args.patience
236
  )
237
  start_training_time = time.time()
238
  history = model.fit(
239
      train_data,
240
      epochs=args.train_epochs,
241
      validation_data=val_data,
242
      callbacks=[checkpoint_callback, early_stop_callback],
243
  )
244
  end_training_time = time.time()
245
  elasped_training_time = end_training_time - start_training_time
246
  print(f'Training finished in {elasped_training_time} secconds')
247

248
  # evaluate best model
249
  best_epoch = np.argmin(history.history['val_loss'])
250
  model.load_weights(checkpoint_path)
251
  test_result = model.evaluate(test_data)
252
  if args.delete_checkpoint:
253
    for f in glob.glob(checkpoint_path + '*'):
254
      os.remove(f)
255

256
  # save result to csv
257
  data = {
258
      'data': [args.data],
259
      'model': [args.model],
260
      'seq_len': [args.seq_len],
261
      'pred_len': [args.pred_len],
262
      'lr': [args.learning_rate],
263
      'mse': [test_result[0]],
264
      'mae': [test_result[1]],
265
      'val_mse': [history.history['val_loss'][best_epoch]],
266
      'val_mae': [history.history['val_mae'][best_epoch]],
267
      'train_mse': [history.history['loss'][best_epoch]],
268
      'train_mae': [history.history['mae'][best_epoch]],
269
      'training_time': elasped_training_time,
270
      'norm_type': args.norm_type,
271
      'activation': args.activation,
272
      'n_block': args.n_block,
273
      'dropout': args.dropout,
274
  }
275
  if 'TSMixer' in args.model:
276
    data['ff_dim'] = args.ff_dim
277

278
  df = pd.DataFrame(data)
279
  if os.path.exists(args.result_path):
280
    df.to_csv(args.result_path, mode='a', index=False, header=False)
281
  else:
282
    df.to_csv(args.result_path, mode='w', index=False, header=True)


if __name__ == '__main__':
286
  main()