google-research

Форк
0
199 строк · 6.5 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
# coding=utf-8
17
# Copyright 2022 The Google Research Authors.
18
#
19
# Licensed under the Apache License, Version 2.0 (the "License");
20
# you may not use this file except in compliance with the License.
21
# You may obtain a copy of the License at
22
#
23
#     http://www.apache.org/licenses/LICENSE-2.0
24
#
25
# Unless required by applicable law or agreed to in writing, software
26
# distributed under the License is distributed on an "AS IS" BASIS,
27
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28
# See the License for the specific language governing permissions and
29
# limitations under the License.
30
"""Data processing and loading classes for time series datasets.
31

32
Defines data-related classes for data preprocessing, data splitting and
dataset loading.
34
"""
35

36
from pytorch_lightning import LightningDataModule
37
import torch
38
from torch.utils.data import DataLoader
39
from torch.utils.data import Dataset
40

41

42
class StandardScaler():
  """Applies standardization transformation to data.

  The forward transformation is `(x - mean) / std` and the inverse is
  `x * std + mean`. Entries of `std` equal to zero (constant features) are
  treated as a scale of one in both directions, so the forward pass never
  divides by zero and the round trip stays consistent.
  """

  def __init__(self, mean, std):
    """Instantiates the data transformation.

    Args:
      mean: mean to be subtracted from the data; converted to a float tensor.
      std: standard deviation to be divided; converted to a float tensor.
    """
    self.mean = torch.as_tensor(mean, dtype=torch.float)
    self.std = torch.as_tensor(std, dtype=torch.float)

  def _stats_for(self, inp_seq):
    """Returns (mean, std) ready to be combined with `inp_seq`.

    Args:
      inp_seq: the data about to be transformed.

    Returns:
      Tuple of (mean, std) where std has its zero entries replaced by one and,
      when `inp_seq` is a tensor, both live on the same device as `inp_seq`.
    """
    mean, std = self.mean, self.std
    if isinstance(inp_seq, torch.Tensor):
      # The statistics are created on the default device; follow the input so
      # that e.g. GPU model outputs can be transformed directly.
      mean = mean.to(inp_seq.device)
      std = std.to(inp_seq.device)
    # A zero standard deviation would turn the division into inf/NaN.
    std = torch.where(std == 0, torch.ones_like(std), std)
    return mean, std

  def __call__(self, inp_seq):
    """Applies the forward standardization transformation to input data.

    Args:
      inp_seq: input data, must be broadcastable against self.mean and
        self.std.

    Returns:
      new_inp: transformed data.
    """
    mean, std = self._stats_for(inp_seq)
    new_inp = (inp_seq - mean) / std
    return new_inp

  def inverse_transform(self, inp_seq):
    """Applies the inverse standardization transformation to input data.

    Usually used to transform the predicted time series from the neural
    network model to real time series.

    Args:
      inp_seq: input data, must be broadcastable against self.mean and
        self.std.

    Returns:
      new_inp: transformed data.
    """
    mean, std = self._stats_for(inp_seq)
    new_inp = inp_seq * std + mean
    return new_inp
81

82

83
class TSDataset(Dataset):
  """PyTorch dataset interface class for time series data.

  Each sample is a sliding window over the first axis of the series: the first
  `input_len` steps form the model input and the following `output_len` steps
  form the prediction target.
  """

  def __init__(self, x, args, transform=None):
    """Instantiates the dataset interface.

    Args:
      x: the array of time series data, with the shape (seq_len x
        num_nodes x num_features); converted to a float tensor.
      args: namespace of arguments; only input_len, output_len, input_dim and
        output_dim are read here.
      transform: optional callable applied to both the input and the output
        window, we use StandardScaler here.
    """
    self.x = torch.as_tensor(x, dtype=torch.float)
    self.input_len = args.input_len
    self.output_len = args.output_len
    self.input_dim = args.input_dim
    self.output_dim = args.output_dim
    self.window_len = self.input_len + self.output_len
    self.transform = transform

  def __getitem__(self, index):
    """Gets the time series data sample corresponding to the index.

    Args:
      index: data sample index; negative values count from the end.

    Returns:
      Tuple of (input time series window, ground truth of output time
        series window).

    Raises:
      IndexError: if index does not address a complete window. Without this
        check slicing would silently return truncated windows and plain
        iteration over the dataset would never terminate.
    """
    num_samples = len(self)
    if index < 0:
      index = index + num_samples
    if not 0 <= index < num_samples:
      raise IndexError(
          f'sample index out of range for dataset with {num_samples} samples')

    split = index + self.input_len
    stop = index + self.window_len
    # Only the leading input_dim / output_dim features are kept.
    input_seq = self.x[index:split][Ellipsis, :self.input_dim]
    output_seq = self.x[split:stop][Ellipsis, :self.output_dim]

    # NOTE(review): the transform is applied after feature slicing, so its
    # statistics must broadcast against the sliced windows - confirm when
    # input_dim/output_dim differ from the number of features.
    if self.transform:
      input_seq = self.transform(input_seq)
      output_seq = self.transform(output_seq)

    return (input_seq, output_seq)

  def __len__(self):
    """Returns the total number of data samples in the dataset.

    A series of length T holds T - window_len + 1 complete windows; the
    result is clamped at zero for series shorter than one window.
    """
    return max(0, self.x.shape[0] - self.window_len + 1)
130

131

132
class DataModule(LightningDataModule):
  """PyTorch Lightning data module class for time series data."""

  def __init__(self, data, args):
    """Instantiates the data module.

    Args:
      data: the numpy array of time series data, with the shape (seq_len x
        num_nodes x num_features).
      args: namespace of arguments; batch_size and num_workers are read here,
        splits is read in setup(), and the whole object is forwarded to
        TSDataset.
    """
    super().__init__()
    self.data = data  # [seq_len, num_nodes, num_features]
    self.args = args
    self.batch_size = args.batch_size
    self.num_workers = args.num_workers

  def setup(self, stage=None):
    """Splits the data and defines the preprocessing transformation.

    Splits the data chronologically into train/val/test sets, where
    self.args.splits[0] is the train ratio, self.args.splits[1] is the test
    ratio and the validation set takes whatever remains in between. The
    standardization statistics are computed on the training split only.

    Args:
      stage: stage name passed by the Lightning trainer when it invokes this
        hook; unused, accepted so that both `setup()` and `setup(stage=...)`
        work.
    """
    del stage  # The same split is used for every stage.
    num_samples = self.data.shape[0]

    num_train = round(num_samples * self.args.splits[0])
    num_test = round(num_samples * self.args.splits[1])
    # Rounding both ratios can overshoot the series length; clamp the test
    # size so that the test split never overlaps the training split.
    num_test = max(0, min(num_test, num_samples - num_train))
    num_val = num_samples - num_train - num_test
    # Use an explicit start index: `data[-num_test:]` would return the whole
    # series when num_test is zero.
    test_start = num_train + num_val

    self.x_train = self.data[:num_train]  # train series
    self.x_val = self.data[num_train:test_start]  # val_series
    self.x_test = self.data[test_start:]  # test_series

    # Per-feature statistics, reduced over the time and node axes.
    self.min_vals = self.x_train.min(axis=(0, 1), keepdims=True)
    self.max_vals = self.x_train.max(axis=(0, 1), keepdims=True)
    self.mean = self.x_train.mean(axis=(0, 1), keepdims=True)
    self.std = self.x_train.std(axis=(0, 1), keepdims=True)
    self.transform = StandardScaler(self.mean, self.std)

  def _make_loader(self, series, shuffle):
    """Wraps one split into a DataLoader of standardized sliding windows."""
    dataset = TSDataset(series, self.args, transform=self.transform)
    return DataLoader(
        dataset,
        batch_size=self.batch_size,
        num_workers=self.num_workers,
        shuffle=shuffle)

  def train_dataloader(self):
    """Returns the training data loader (shuffled)."""
    return self._make_loader(self.x_train, shuffle=True)

  def val_dataloader(self):
    """Returns the validation data loader (not shuffled)."""
    return self._make_loader(self.x_val, shuffle=False)

  def test_dataloader(self):
    """Returns the testing data loader (not shuffled)."""
    return self._make_loader(self.x_test, shuffle=False)
200

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.