# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Load raw data and generate time series dataset."""

import os

import numpy as np
import pandas as pd
from sklearn.preprocessing import StandardScaler
import tensorflow as tf


DATA_DIR = 'gs://time_series_datasets'
LOCAL_CACHE_DIR = './dataset/'


class TSFDataLoader:
  """Generate data loader from raw data."""

  def __init__(
      self, data, batch_size, seq_len, pred_len, feature_type, target='OT'
  ):
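    """Initializes the data loader.

    Args:
      data: dataset name; `<data>.csv` is fetched from `DATA_DIR` and
        cached locally (e.g. `ETTh1`).
      batch_size: batch size of the generated `tf.data.Dataset`s.
      seq_len: length of the input (look-back) window.
      pred_len: length of the prediction horizon.
      feature_type: one of `S`, `M`, or `MS`; see `_read_data`.
      target: name of the target column for `S` and `MS` modes.
    """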
    self.data = data
    self.batch_size = batch_size
    self.seq_len = seq_len
    self.pred_len = pred_len
    self.feature_type = feature_type
    self.target = target
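    # By default all features are predicted; `MS` mode narrows this slice
    # to the single target column in `_read_data`.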
    self.target_slice = slice(0, None)

    self._read_data()

  def _read_data(self):
    """Load raw data and split datasets."""

    # Copy data from cloud storage if it is not already cached locally.
    if not os.path.isdir(LOCAL_CACHE_DIR):
      os.mkdir(LOCAL_CACHE_DIR)

    file_name = self.data + '.csv'
    cache_filepath = os.path.join(LOCAL_CACHE_DIR, file_name)
    if not os.path.isfile(cache_filepath):
      tf.io.gfile.copy(
          os.path.join(DATA_DIR, file_name), cache_filepath, overwrite=True
      )

    df_raw = pd.read_csv(cache_filepath)

    # Feature modes: `S` predicts the univariate target from the target
    # alone, `M` predicts all features from all features, and `MS` predicts
    # the univariate target from all features.
    df = df_raw.set_index('date')
    if self.feature_type == 'S':
      df = df[[self.target]]
    elif self.feature_type == 'MS':
      target_idx = df.columns.get_loc(self.target)
      self.target_slice = slice(target_idx, target_idx + 1)

    # split train/valid/test
    n = len(df)
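    # The ETT benchmarks use a fixed 12/4/4-month split (30-day months);
    # ETTm is sampled every 15 minutes (4 points per hour), ETTh hourly.
    # All other datasets are split 70/10/20 by ratio.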
    if self.data.startswith('ETTm'):
      train_end = 12 * 30 * 24 * 4
      val_end = train_end + 4 * 30 * 24 * 4
      test_end = val_end + 4 * 30 * 24 * 4
    elif self.data.startswith('ETTh'):
      train_end = 12 * 30 * 24
      val_end = train_end + 4 * 30 * 24
      test_end = val_end + 4 * 30 * 24
    else:
      train_end = int(n * 0.7)
      val_end = n - int(n * 0.2)
      test_end = n
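    # Prepend `seq_len` points to val/test so that the first forecast
    # window has a full look-back context.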
    train_df = df[:train_end]
    val_df = df[train_end - self.seq_len : val_end]
    test_df = df[val_end - self.seq_len : test_end]

    # Standardize features using statistics from the training set only.
    self.scaler = StandardScaler()
    self.scaler.fit(train_df.values)

    def scale_df(df, scaler):
      data = scaler.transform(df.values)
      return pd.DataFrame(data, index=df.index, columns=df.columns)

    self.train_df = scale_df(train_df, self.scaler)
    self.val_df = scale_df(val_df, self.scaler)
    self.test_df = scale_df(test_df, self.scaler)
    self.n_feature = self.train_df.shape[-1]

  def _split_window(self, data):
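    # Each window produced by `timeseries_dataset_from_array` has length
    # `seq_len + pred_len`: the first `seq_len` steps are the model input,
    # the last `pred_len` steps the forecast labels.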
    inputs = data[:, : self.seq_len, :]
    labels = data[:, self.seq_len :, self.target_slice]
    # Slicing doesn't preserve static shape information, so set the shapes
    # manually. This way the `tf.data.Datasets` are easier to inspect.
    inputs.set_shape([None, self.seq_len, None])
    labels.set_shape([None, self.pred_len, None])
    return inputs, labels

  def _make_dataset(self, data, shuffle=True):
    data = np.array(data, dtype=np.float32)
    ds = tf.keras.utils.timeseries_dataset_from_array(
        data=data,
        targets=None,
        sequence_length=(self.seq_len + self.pred_len),
        sequence_stride=1,
        shuffle=shuffle,
        batch_size=self.batch_size,
    )
    ds = ds.map(self._split_window)
    return ds

  def inverse_transform(self, data):
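    # Maps standardized model outputs back to the original data scale.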
    return self.scaler.inverse_transform(data)

  def get_train(self, shuffle=True):
    return self._make_dataset(self.train_df, shuffle=shuffle)

  def get_val(self):
    return self._make_dataset(self.val_df, shuffle=False)

  def get_test(self):
    return self._make_dataset(self.test_df, shuffle=False)
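

# A minimal usage sketch (added for illustration, not part of the original
# module). `ETTh1` and the window lengths below are placeholder choices;
# running it requires the CSV to exist in DATA_DIR or LOCAL_CACHE_DIR.
if __name__ == '__main__':
  loader = TSFDataLoader(
      data='ETTh1',
      batch_size=32,
      seq_len=336,
      pred_len=96,
      feature_type='M',
  )
  train_ds = loader.get_train()
  # Each batch is (inputs, labels) with shapes (batch, seq_len, n_feature)
  # and (batch, pred_len, n_targets).
  for inputs, labels in train_ds.take(1):
    print(inputs.shape, labels.shape)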
