google-research
134 lines · 4.1 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Load raw data and generate time series dataset."""
17
18import os
19
20import numpy as np
21import pandas as pd
22from sklearn.preprocessing import StandardScaler
23import tensorflow as tf
24
25
# Cloud Storage location holding the raw `<dataset>.csv` files.
DATA_DIR = 'gs://time_series_datasets'
# Local directory the CSVs are copied into on first access.
LOCAL_CACHE_DIR = './dataset/'
28
29
class TSFDataLoader:
  """Generate data loader from raw data.

  Loads a CSV time-series file (copying it from cloud storage into a local
  cache on first use), splits it chronologically into train/val/test
  partitions, standardizes features with statistics computed from the
  training split only, and exposes each partition as a windowed
  `tf.data.Dataset` of (inputs, labels) pairs.
  """

  def __init__(
      self, data, batch_size, seq_len, pred_len, feature_type, target='OT'
  ):
    """Initializes the loader and eagerly reads and splits the dataset.

    Args:
      data: Dataset name; `<data>.csv` is fetched from `DATA_DIR` when not
        already present in `LOCAL_CACHE_DIR`.
      batch_size: Batch size of the generated `tf.data.Dataset`s.
      seq_len: Length (time steps) of the model input window.
      pred_len: Length (time steps) of the prediction horizon.
      feature_type: 'S' (univariate -> univariate), 'M' (multivariate ->
        multivariate) or 'MS' (multivariate -> univariate).
      target: Name of the target column (used for 'S' and 'MS' modes).
    """
    self.data = data
    self.batch_size = batch_size
    self.seq_len = seq_len
    self.pred_len = pred_len
    self.feature_type = feature_type
    self.target = target
    # Slice selecting the label columns; narrowed to the single target
    # column in `_read_data` when feature_type == 'MS'.
    self.target_slice = slice(0, None)

    self._read_data()

  def _read_data(self):
    """Load raw data and split datasets."""

    # Copy data from cloud storage if not cached locally.
    # `makedirs(..., exist_ok=True)` avoids the race between an existence
    # check and directory creation (e.g. two jobs starting at once) and
    # also handles a nested cache path.
    os.makedirs(LOCAL_CACHE_DIR, exist_ok=True)

    file_name = self.data + '.csv'
    cache_filepath = os.path.join(LOCAL_CACHE_DIR, file_name)
    if not os.path.isfile(cache_filepath):
      tf.io.gfile.copy(
          os.path.join(DATA_DIR, file_name), cache_filepath, overwrite=True
      )

    df_raw = pd.read_csv(cache_filepath)

    # S: univariate-univariate, M: multivariate-multivariate, MS:
    # multivariate-univariate
    df = df_raw.set_index('date')
    if self.feature_type == 'S':
      df = df[[self.target]]
    elif self.feature_type == 'MS':
      target_idx = df.columns.get_loc(self.target)
      self.target_slice = slice(target_idx, target_idx + 1)

    # Split train/valid/test. The ETT datasets use the fixed benchmark
    # splits (12/4/4 months of 15-minute or hourly samples); every other
    # dataset is split 70% / 10% / 20% chronologically.
    n = len(df)
    if self.data.startswith('ETTm'):
      train_end = 12 * 30 * 24 * 4
      val_end = train_end + 4 * 30 * 24 * 4
      test_end = val_end + 4 * 30 * 24 * 4
    elif self.data.startswith('ETTh'):
      train_end = 12 * 30 * 24
      val_end = train_end + 4 * 30 * 24
      test_end = val_end + 4 * 30 * 24
    else:
      train_end = int(n * 0.7)
      val_end = n - int(n * 0.2)
      test_end = n
    train_df = df[:train_end]
    # Val/test partitions start `seq_len` steps early so the first forecast
    # window in each partition has a full input history.
    val_df = df[train_end - self.seq_len : val_end]
    test_df = df[val_end - self.seq_len : test_end]

    # Standardize with training-set statistics only, to avoid leaking
    # val/test information into the model.
    self.scaler = StandardScaler()
    self.scaler.fit(train_df.values)

    def scale_df(df, scaler):
      # Transform values while preserving the original index/columns.
      data = scaler.transform(df.values)
      return pd.DataFrame(data, index=df.index, columns=df.columns)

    self.train_df = scale_df(train_df, self.scaler)
    self.val_df = scale_df(val_df, self.scaler)
    self.test_df = scale_df(test_df, self.scaler)
    self.n_feature = self.train_df.shape[-1]

  def _split_window(self, data):
    """Split a (batch, seq_len + pred_len, features) window into (inputs, labels)."""
    inputs = data[:, : self.seq_len, :]
    labels = data[:, self.seq_len :, self.target_slice]
    # Slicing doesn't preserve static shape information, so set the shapes
    # manually. This way the `tf.data.Datasets` are easier to inspect.
    inputs.set_shape([None, self.seq_len, None])
    labels.set_shape([None, self.pred_len, None])
    return inputs, labels

  def _make_dataset(self, data, shuffle=True):
    """Build a windowed `tf.data.Dataset` of (inputs, labels) from `data`."""
    data = np.array(data, dtype=np.float32)
    ds = tf.keras.utils.timeseries_dataset_from_array(
        data=data,
        targets=None,
        sequence_length=(self.seq_len + self.pred_len),
        sequence_stride=1,
        shuffle=shuffle,
        batch_size=self.batch_size,
    )
    ds = ds.map(self._split_window)
    return ds

  def inverse_transform(self, data):
    """Map standardized values back to the original data scale."""
    return self.scaler.inverse_transform(data)

  def get_train(self, shuffle=True):
    """Returns the training dataset (shuffled by default)."""
    return self._make_dataset(self.train_df, shuffle=shuffle)

  def get_val(self):
    """Returns the validation dataset (never shuffled)."""
    return self._make_dataset(self.val_df, shuffle=False)

  def get_test(self):
    """Returns the test dataset (never shuffled)."""
    return self._make_dataset(self.test_df, shuffle=False)
135