google-research

Форк
0
177 строк · 6.1 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Defines the WindowedSequenceDataset."""
17

18
# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
19
# pylint: disable=invalid-name
20

21
import numpy as np
22
import torch
23
import torch.utils.data
24
from utils.preprocessing_utils import compute_windows
25
from utils.preprocessing_utils import preprocess_numpy_data
26

27

28
class WindowedSequenceDataset(torch.utils.data.dataset.Dataset):
  """WindowedSequenceDataset object.

  PyTorch dataset that slices each time series into sliding input windows and
  the forecasting targets that follow them, so they can be fed to a neural
  sequence model (e.g. an LSTM).
  """

  def __init__(
      self,
      x,
      lengths,
      forecasting_horizon,
      minmax_scaling,
      input_window_size,
      default_nan_value=1e15,
      per_series_scaling=True,
      target_dims=(0,),
  ):
    """Initializes the WindowedSequenceDataset object.

    Takes in a numpy array of sequences and pre-processes it (via
    `preprocess_numpy_data`) into tensors ready for a PyTorch model to learn
    from and predict on.

    Args:
      x: numpy array of all sequences. Its last axis is enumerated as the
        feature dimensions, and windowing later unpacks the processed data as
        N x T x D, so x is presumably N x T x D -- TODO confirm with callers.
      lengths: numpy array (N) containing the length of each sequence.
      forecasting_horizon: desired number of timesteps into the future for
        the model to predict.
      minmax_scaling: whether to scale the time series to be between 0 and 1.
      input_window_size: number of past timesteps for the model to use as
        input.
      default_nan_value: value to impute nans with; also used to pad the
        timestamps of positions beyond each series' length.
      per_series_scaling: whether to scale differently for each series.
      target_dims: dimensions containing target values. Every other index of
        the last axis of `x` is passed to `preprocess_numpy_data` as
        `ignore_dims`.
    """
    super().__init__()
    self.forecasting_horizon = forecasting_horizon
    self.minmax_scaling = minmax_scaling
    self.input_window_size = input_window_size
    self.default_nan_value = default_nan_value

    # Materialize target_dims once instead of on every comprehension step.
    target_dims = tuple(target_dims)
    ignore_dims = [i for i in range(x.shape[-1]) if i not in target_dims]
    self.x, self.x_offset, self.x_scale = preprocess_numpy_data(
        x,
        minmax_scaling,
        default_nan_value=default_nan_value,
        per_series_scaling=per_series_scaling,
        ignore_dims=ignore_dims,
    )
    self.lengths = torch.from_numpy(lengths)

    # __getitem__ indexes these by series when their first axis has more than
    # one row, so they must be 2-D.
    assert len(self.x_scale.shape) == 2
    assert len(self.x_offset.shape) == 2

  def _compute_inputs_and_targets(
      self, x, lengths, default_nan_value
  ):
    """Format x tensor into tensor of inputs and targets.

    Args:
      x: tensor (N x T x D) containing all time series.
      lengths: tensor (N) containing the length of each series.
      default_nan_value: value used to pad timestamps past each series' length.

    Returns:
      inputs: tensor, N x T_sliding x input_window_size x D, where
        T_sliding = T - input_window_size + 1.
      targets: tensor, N x T_sliding x forecasting_horizon x D.
      input_times: 1-based timestamps windowed by `compute_windows` the same
        way as inputs; positions past a series' length hold default_nan_value.
      target_times: timestamps windowed the same way as targets.
      target_times_mask: 5-D tensor (the output of `repeat` with five repeat
        sizes) whose last axis has size T + 1 - input_window_size. Entry m on
        that axis is 1.0 where the corresponding target timestamp is
        <= input_window_size + m + 1, else 0.0. This is useful when
        aggregating all forecast errors up to each timepoint. The layout of
        the leading axes depends on what `compute_windows` returns for the
        2-D timestamp tensor -- TODO confirm.
    """
    N, T, D = x.shape

    window = self.input_window_size
    forecasting_horizon = self.forecasting_horizon
    T_sliding = T - window + 1

    inputs, targets = compute_windows(x, window, forecasting_horizon)
    assert inputs.shape[0] == N
    assert inputs.shape[1] == T_sliding
    assert inputs.shape[2] == window
    assert inputs.shape[3] == D

    assert targets.shape[0] == N
    assert targets.shape[1] == T_sliding
    assert targets.shape[2] == forecasting_horizon
    assert targets.shape[3] == D

    # Timestamps are 1-based within each series; padded positions carry the
    # sentinel so they never satisfy the `<= t` checks below for small t.
    times = np.full((N, T), default_nan_value, dtype=np.float64)
    for i, length in enumerate(lengths):
      length = int(length)
      times[i, :length] = np.arange(1, length + 1)
    times = torch.from_numpy(times)
    input_times, target_times = compute_windows(
        times, window, forecasting_horizon
    )

    # Compare every target timestamp against all thresholds t = 1..T+1 in a
    # single broadcast (the last axis of repeat_ts indexes t - 1).
    repeat_ts = target_times.unsqueeze(-1).repeat(1, 1, 1, 1, T + 1)
    thresholds = torch.arange(
        1, T + 2, dtype=repeat_ts.dtype, device=repeat_ts.device
    )
    # torch.zeros(...) in the loop-based formulation used the default dtype;
    # keep that result dtype.
    target_times_mask = (repeat_ts <= thresholds).to(torch.get_default_dtype())
    # Drop thresholds covered by the first input window (starting e.g. at
    # timepoint 37): entry m now compares against window + m + 1.
    target_times_mask = target_times_mask[:, :, :, :, window:]
    return inputs, targets, input_times, target_times, target_times_mask

  def __getitem__(self, index):
    """Returns the windowed inputs/targets and metadata for one series."""
    inputs, targets, input_times, target_times, target_times_mask = (
        self._compute_inputs_and_targets(
            self.x[index : index + 1],
            self.lengths[index : index + 1],
            default_nan_value=self.default_nan_value,
        )
    )

    assert inputs.shape[0] == 1
    assert targets.shape[0] == 1
    assert input_times.shape[0] == 1
    assert target_times.shape[0] == 1

    # A single row of scaling statistics is shared by all series; otherwise
    # pick the row belonging to this series.
    if self.x_offset.shape[0] > 1:
      x_offset = self.x_offset[index]
      x_scale = self.x_scale[index]
    else:
      x_offset = self.x_offset
      x_scale = self.x_scale

    return {
        'x': self.x[index],
        'x_offset': x_offset,
        'x_scale': x_scale,
        'inputs': inputs[0],
        'targets': targets[0],
        'lengths': self.lengths[index],
        'input_times': input_times[0],
        'target_times': target_times[0],
        'target_times_mask': target_times_mask[0],
    }

  def __len__(self):
    """Returns the number of series in the dataset."""
    return len(self.x)
178

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.