google-research
286 строк · 8.1 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Train and evaluate models for time series forecasting."""
17
18import argparse19import glob20import logging21import os22import time23
24from data_loader import TSFDataLoader25import models26import numpy as np27import pandas as pd28import tensorflow as tf29
# Quiet TensorFlow: the environment variable limits the C++ runtime to FATAL
# messages ('3'), and the Python-side 'tensorflow' logger is raised to FATAL.
# NOTE(review): `tensorflow` is imported before these lines run, so messages
# emitted by the C++ runtime during import may not be suppressed — consider
# setting the variable before `import tensorflow` if that matters.
os.environ.update(TF_CPP_MIN_LOG_LEVEL='3')
logging.getLogger('tensorflow').setLevel(logging.FATAL)
33
def parse_args(argv=None):
  """Parse the arguments for experiment configuration.

  Args:
    argv: Optional list of command-line tokens to parse. When None (the
      default), argparse reads `sys.argv[1:]`, exactly as before.

  Returns:
    An `argparse.Namespace` holding the experiment configuration.

  Side effects:
    Seeds the global random number generators through
    `tf.keras.utils.set_random_seed(args.seed)`.
  """
  parser = argparse.ArgumentParser(
      description='TSMixer for Time Series Forecasting'
  )

  # basic config
  parser.add_argument('--seed', type=int, default=0, help='random seed')
  parser.add_argument(
      '--model',
      type=str,
      default='tsmixer',
      # Kept in sync with the model dispatch in main(): TSMixer variants
      # (names containing 'tsmixer'), 'full_linear' and 'cnn'.
      help='model name, options: [tsmixer, tsmixer_rev_in, full_linear, cnn]',
  )

  # data loader
  parser.add_argument(
      '--data',
      type=str,
      default='weather',
      choices=[
          'electricity',
          'exchange_rate',
          'national_illness',
          'traffic',
          'weather',
          'ETTm1',
          'ETTm2',
          'ETTh1',
          'ETTh2',
      ],
      help='data name',
  )
  parser.add_argument(
      '--feature_type',
      type=str,
      default='M',
      choices=['S', 'M', 'MS'],
      help=(
          'forecasting task, options:[M, S, MS]; M:multivariate predict'
          ' multivariate, S:univariate predict univariate, MS:multivariate'
          ' predict univariate'
      ),
  )
  parser.add_argument(
      '--target', type=str, default='OT', help='target feature in S or MS task'
  )
  parser.add_argument(
      '--checkpoint_dir',
      type=str,
      default='./checkpoints/',
      help='location of model checkpoints',
  )
  parser.add_argument(
      '--delete_checkpoint',
      action='store_true',
      help='delete checkpoints after the experiment',
  )

  # forecasting task
  parser.add_argument(
      '--seq_len', type=int, default=336, help='input sequence length'
  )
  parser.add_argument(
      '--pred_len', type=int, default=96, help='prediction sequence length'
  )

  # model hyperparameter
  parser.add_argument(
      '--n_block',
      type=int,
      default=2,
      help='number of block for deep architecture',
  )
  parser.add_argument(
      '--ff_dim',
      type=int,
      default=2048,
      help='fully-connected feature dimension',
  )
  parser.add_argument(
      '--dropout', type=float, default=0.05, help='dropout rate'
  )
  parser.add_argument(
      '--norm_type',
      type=str,
      default='B',
      choices=['L', 'B'],
      help='LayerNorm or BatchNorm',
  )
  parser.add_argument(
      '--activation',
      type=str,
      default='relu',
      choices=['relu', 'gelu'],
      help='Activation function',
  )
  parser.add_argument(
      '--kernel_size', type=int, default=4, help='kernel size for CNN'
  )
  # NOTE(review): --temporal_dim and --hidden_dim are not read by main() in
  # this file; kept so existing command lines keep parsing.
  parser.add_argument(
      '--temporal_dim', type=int, default=16, help='temporal feature dimension'
  )
  parser.add_argument(
      '--hidden_dim', type=int, default=64, help='hidden feature dimension'
  )

  # optimization
  # NOTE(review): --num_workers is not read by main() in this file.
  parser.add_argument(
      '--num_workers', type=int, default=10, help='data loader num workers'
  )
  parser.add_argument(
      '--train_epochs', type=int, default=100, help='train epochs'
  )
  parser.add_argument(
      '--batch_size', type=int, default=32, help='batch size of input data'
  )
  parser.add_argument(
      '--learning_rate',
      type=float,
      default=0.0001,
      help='optimizer learning rate',
  )
  parser.add_argument(
      '--patience', type=int, default=5, help='number of epochs to early stop'
  )

  # save results
  parser.add_argument(
      '--result_path', default='result.csv', help='path to save result'
  )

  # argv=None makes argparse fall back to sys.argv[1:].
  args = parser.parse_args(argv)

  # Seed once, right after parsing, so every later random operation in the
  # experiment is reproducible for a given --seed.
  tf.keras.utils.set_random_seed(args.seed)

  return args
173
def _build_exp_id(args):
  """Return the experiment identifier used to name the checkpoint files.

  The id encodes the dataset, task type, model name, sequence lengths,
  learning rate, and the hyperparameters specific to the model family.

  Raises:
    ValueError: if `args.model` is neither a TSMixer variant (name containing
      'tsmixer') nor 'full_linear' nor 'cnn'.
  """
  prefix = (
      f'{args.data}_{args.feature_type}_{args.model}'
      f'_sl{args.seq_len}_pl{args.pred_len}_lr{args.learning_rate}'
  )
  if 'tsmixer' in args.model:
    return (
        f'{prefix}_nt{args.norm_type}_{args.activation}'
        f'_nb{args.n_block}_dp{args.dropout}_fd{args.ff_dim}'
    )
  if args.model == 'full_linear':
    return prefix
  if args.model == 'cnn':
    return f'{prefix}_ks{args.kernel_size}'
  raise ValueError(f'Unknown model type: {args.model}')


def _build_model(args, data_loader):
  """Instantiate the Keras model selected by `args.model`.

  TSMixer variants are looked up by name as a submodule of `models` and built
  with that submodule's `build_model`; 'full_linear' and 'cnn' use the
  `Model` class of the matching submodule.

  Raises:
    ValueError: if `args.model` is not a supported model name.
  """
  if 'tsmixer' in args.model:
    build_model = getattr(models, args.model).build_model
    return build_model(
        input_shape=(args.seq_len, data_loader.n_feature),
        pred_len=args.pred_len,
        norm_type=args.norm_type,
        activation=args.activation,
        dropout=args.dropout,
        n_block=args.n_block,
        ff_dim=args.ff_dim,
        target_slice=data_loader.target_slice,
    )
  if args.model == 'full_linear':
    return models.full_linear.Model(
        n_channel=data_loader.n_feature,
        pred_len=args.pred_len,
    )
  if args.model == 'cnn':
    return models.cnn.Model(
        n_channel=data_loader.n_feature,
        pred_len=args.pred_len,
        kernel_size=args.kernel_size,
    )
  raise ValueError(f'Model not supported: {args.model}')


def _save_result(data, result_path):
  """Write one experiment's results as a single row of a CSV file.

  Args:
    data: Mapping from column name to a one-element list holding its value.
    result_path: Destination CSV path.

  A missing or empty file is (re)created with a header row. Otherwise the row
  is appended without a header, with its columns reordered to match the
  header already present in the file so every value lands under the right
  column. Columns absent from that header are dropped with a warning, and
  header columns this row lacks are left empty.
  """
  df = pd.DataFrame(data)
  if not os.path.exists(result_path) or os.path.getsize(result_path) == 0:
    df.to_csv(result_path, mode='w', index=False, header=True)
    return
  # Append instead of rewriting the whole file so rows written by other runs
  # sharing this results file are never overwritten.
  header = pd.read_csv(result_path, nrows=0).columns.tolist()
  dropped = [column for column in df.columns if column not in header]
  if dropped:
    logging.warning(
        'Columns %s are not in the header of %s and will not be saved.',
        dropped,
        result_path,
    )
  df.reindex(columns=header).to_csv(
      result_path, mode='a', index=False, header=False
  )


def main():
  """Train one forecasting model, evaluate its best checkpoint, save results.

  Steps: parse the CLI arguments, load the train/val/test splits, build and
  compile the requested model, train it with best-checkpointing and early
  stopping on `val_loss`, reload the best weights, evaluate on the test split,
  optionally delete the checkpoint files, and record one CSV row of metrics
  at `args.result_path`.
  """
  args = parse_args()
  # Computed before loading any data so an unknown model name fails fast.
  exp_id = _build_exp_id(args)

  # load datasets
  data_loader = TSFDataLoader(
      args.data,
      args.batch_size,
      args.seq_len,
      args.pred_len,
      args.feature_type,
      args.target,
  )
  train_data = data_loader.get_train()
  val_data = data_loader.get_val()
  test_data = data_loader.get_test()

  # train model
  model = _build_model(args, data_loader)
  optimizer = tf.keras.optimizers.Adam(learning_rate=args.learning_rate)
  model.compile(optimizer=optimizer, loss='mse', metrics=['mae'])
  checkpoint_path = os.path.join(args.checkpoint_dir, f'{exp_id}_best')
  # With save_best_only, the checkpoint holds the weights of the epoch with
  # the best monitored value (Keras' default monitor is 'val_loss').
  checkpoint_callback = tf.keras.callbacks.ModelCheckpoint(
      filepath=checkpoint_path,
      verbose=1,
      save_best_only=True,
      save_weights_only=True,
  )
  early_stop_callback = tf.keras.callbacks.EarlyStopping(
      monitor='val_loss', patience=args.patience
  )
  # perf_counter is monotonic, so the duration is immune to clock changes.
  start_training_time = time.perf_counter()
  history = model.fit(
      train_data,
      epochs=args.train_epochs,
      validation_data=val_data,
      callbacks=[checkpoint_callback, early_stop_callback],
  )
  elapsed_training_time = time.perf_counter() - start_training_time
  print(f'Training finished in {elapsed_training_time} seconds')

  # evaluate best model
  metrics = history.history
  # Index of the epoch whose weights were checkpointed (lowest val_loss).
  best_epoch = int(np.argmin(metrics['val_loss']))
  model.load_weights(checkpoint_path)
  test_result = model.evaluate(test_data)  # [mse loss, mae]
  if args.delete_checkpoint:
    # Escape the path so glob metacharacters in checkpoint_dir are literal.
    for path in glob.glob(glob.escape(checkpoint_path) + '*'):
      os.remove(path)

  # save result to csv
  data = {
      'data': [args.data],
      'model': [args.model],
      'seq_len': [args.seq_len],
      'pred_len': [args.pred_len],
      'lr': [args.learning_rate],
      'mse': [test_result[0]],
      'mae': [test_result[1]],
      'val_mse': [metrics['val_loss'][best_epoch]],
      'val_mae': [metrics['val_mae'][best_epoch]],
      'train_mse': [metrics['loss'][best_epoch]],
      'train_mae': [metrics['mae'][best_epoch]],
      'training_time': [elapsed_training_time],
      'norm_type': [args.norm_type],
      'activation': [args.activation],
      'n_block': [args.n_block],
      'dropout': [args.dropout],
      # ff_dim only applies to TSMixer variants (lowercase names such as
      # 'tsmixer' / 'tsmixer_rev_in'). The column is always present so every
      # row has the same schema; it is empty for other models.
      'ff_dim': [args.ff_dim if 'tsmixer' in args.model else None],
  }
  _save_result(data, args.result_path)
284
if __name__ == '__main__':
  # Run one train/evaluate experiment only when executed as a script, not
  # when this module is imported.
  main()