# google-research
# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Data processing and loading classes for time series datasets.

Defines data-related classes for data preprocessing, data splitting and
dataset loading.
"""

from pytorch_lightning import LightningDataModule
import torch
from torch.utils.data import DataLoader
from torch.utils.data import Dataset


class StandardScaler:
  """Standardizes data by removing a mean and scaling by a standard deviation."""

  def __init__(self, mean, std):
    """Stores the statistics used by the transformation.

    Args:
      mean: mean to be subtracted from the data.
      std: standard deviation to be divided.
    """
    self.mean = torch.as_tensor(mean, dtype=torch.float)
    self.std = torch.as_tensor(std, dtype=torch.float)

  def __call__(self, inp_seq):
    """Applies the forward standardization (x - mean) / std.

    Args:
      inp_seq: input data; must broadcast against self.mean and self.std.

    Returns:
      The standardized data.
    """
    return (inp_seq - self.mean) / self.std

  def inverse_transform(self, inp_seq):
    """Applies the inverse standardization x * std + mean.

    Usually used to map predictions from the neural network model back to
    the real time series scale.

    Args:
      inp_seq: input data; must broadcast against self.mean and self.std.

    Returns:
      The de-standardized data.
    """
    return inp_seq * self.std + self.mean
class TSDataset(Dataset):
  """PyTorch dataset interface class for time series data.

  Yields sliding (input window, output window) pairs over a single series.
  """

  def __init__(self, x, args, transform=None):
    """Instantiates the dataset interface.

    Args:
      x: the numpy array of time series data, with the shape (seq_len x
        num_nodes x num_features).
      args: python argparse.ArgumentParser class; only the data-related
        attributes input_len, output_len, input_dim and output_dim are read.
      transform: the transformation applied to the data, we use
        StandardScaler here.
    """
    self.x = torch.as_tensor(x, dtype=torch.float)
    self.input_len = args.input_len
    self.output_len = args.output_len
    self.input_dim = args.input_dim
    self.output_dim = args.output_dim
    # Total span of one sample: input window followed by output window.
    self.window_len = self.input_len + self.output_len
    self.transform = transform

  def __getitem__(self, index):
    """Gets the time series data sample corresponding to the index.

    Args:
      index: data sample index (start position of the input window).

    Returns:
      Tuple of (input time series window, ground truth of output time
      series window).
    """
    input_seq = self.x[index:index + self.input_len]
    input_seq = input_seq[..., :self.input_dim]
    output_seq = self.x[index + self.input_len:index + self.window_len]
    output_seq = output_seq[..., :self.output_dim]

    if self.transform:
      # NOTE(review): the transform's mean/std must broadcast against the
      # feature slices above (input_dim / output_dim features) — confirm
      # they match num_features or broadcast correctly against both.
      input_seq = self.transform(input_seq)
      output_seq = self.transform(output_seq)

    return (input_seq, output_seq)

  def __len__(self):
    """Returns the total number of data samples in the dataset.

    A sample starting at index i spans x[i : i + window_len], so the last
    valid start index is seq_len - window_len and the sample count is
    seq_len - window_len + 1. (The original `seq_len - window_len` was an
    off-by-one that silently dropped the final complete window.) Clamped
    at zero for series shorter than one full window.
    """
    return max(self.x.shape[0] - self.window_len + 1, 0)
class DataModule(LightningDataModule):
  """PyTorch Lightning data module class for time series data."""

  def __init__(self, data, args):
    """Instantiates the data module.

    Args:
      data: the numpy array of time series data, with the shape (seq_len x
        num_nodes x num_features).
      args: python argparse.ArgumentParser class; reads batch_size,
        num_workers, splits, and the TSDataset window attributes.
    """
    super().__init__()
    self.data = data  # [seq_len, num_nodes, num_features]
    self.args = args
    self.batch_size = args.batch_size
    self.num_workers = args.num_workers

  def setup(self, stage=None):
    """Splits the data and defines the preprocessing transformation.

    Splits the data into train/val/test sets with the ratios given in
    self.args.splits = (train_ratio, test_ratio), and fits the
    StandardScaler on the training split only (avoiding test leakage).

    Args:
      stage: optional Lightning stage ('fit', 'validate', 'test', ...).
        Lightning invokes setup(stage=...), so the parameter is required
        for the trainer to call this override (the original `setup(self)`
        signature raised a TypeError under the trainer); all splits are
        prepared regardless of stage.
    """
    del stage  # All splits are prepared unconditionally.
    num_samples = self.data.shape[0]

    num_train = round(num_samples * self.args.splits[0])
    num_test = round(num_samples * self.args.splits[1])
    num_val = num_samples - num_train - num_test

    self.x_train = self.data[:num_train]  # train series
    self.x_val = self.data[num_train:num_train + num_val]  # val series
    # Equivalent to self.data[-num_test:] for num_test > 0, but correctly
    # yields an empty split (not the whole array) when num_test == 0.
    self.x_test = self.data[num_train + num_val:]  # test series

    # min/max are retained for external consumers even though only
    # mean/std feed the StandardScaler below.
    self.min_vals = self.x_train.min(axis=(0, 1), keepdims=True)
    self.max_vals = self.x_train.max(axis=(0, 1), keepdims=True)
    self.mean = self.x_train.mean(axis=(0, 1), keepdims=True)
    self.std = self.x_train.std(axis=(0, 1), keepdims=True)
    self.transform = StandardScaler(self.mean, self.std)

  def _loader(self, x, shuffle):
    """Builds a DataLoader over `x` using the fitted transform."""
    dataset = TSDataset(x, self.args, transform=self.transform)
    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=shuffle)

  def train_dataloader(self):
    """Returns the training data loader (shuffled)."""
    return self._loader(self.x_train, True)

  def val_dataloader(self):
    """Returns the validation data loader."""
    return self._loader(self.x_val, False)

  def test_dataloader(self):
    """Returns the testing data loader."""
    return self._loader(self.x_test, False)