google-research
177 строк · 6.1 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Defines the WindowedSequenceDataset."""
17
18# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
19# pylint: disable=invalid-name
20
21import numpy as np22import torch23import torch.utils.data24from utils.preprocessing_utils import compute_windows25from utils.preprocessing_utils import preprocess_numpy_data26
27
class WindowedSequenceDataset(torch.utils.data.dataset.Dataset):
  """Dataset serving sliding input/target windows over a set of time series.

  PyTorch dataset object which formats sequences to prepare them as input to
  neural sequence models (e.g. an LSTM). Each item corresponds to one series
  and bundles the preprocessed series with its input windows, its target
  windows, the timestamps of those windows and a mask over target timestamps.

  Attributes:
    forecasting_horizon: number of future timesteps in each target window.
    minmax_scaling: the scaling flag that was forwarded to preprocessing.
    input_window_size: number of past timesteps in each input window.
    default_nan_value: value used to impute nans and to pad timestamps.
    x: preprocessed series as returned by `preprocess_numpy_data`.
    x_offset: 2-D offset returned by `preprocess_numpy_data`.
    x_scale: 2-D scale returned by `preprocess_numpy_data`.
    lengths: tensor (N) containing the length of each series.
  """

  def __init__(
      self,
      x,
      lengths,
      forecasting_horizon,
      minmax_scaling,
      input_window_size,
      default_nan_value=1e15,
      per_series_scaling=True,
      target_dims=(0,),
  ):
    """Initializes the WindowedSequenceDataset object.

    Takes in a numpy array of sequences and pre-processes this data into
    tensors ready for PyTorch models to learn from and predict on.

    Args:
      x: numpy array of all sequences. The last axis is treated as the feature
        dimension, and the preprocessed result is later unpacked as
        (N x T x D).
      lengths: numpy array (N) containing the length of each sequence
      forecasting_horizon: desired number of timesteps into the future for
        the model to predict
      minmax_scaling: whether to scale the time series to be between 0 and 1
      input_window_size: number of past timesteps for model to use as input
      default_nan_value: value to impute nans with
      per_series_scaling: whether to scale differently for each series
      target_dims: dimensions containing target values
    """
    super().__init__()
    self.forecasting_horizon = forecasting_horizon
    self.minmax_scaling = minmax_scaling
    self.input_window_size = input_window_size
    self.default_nan_value = default_nan_value

    # Every feature dimension that is not a target dimension is handed to the
    # preprocessing routine through its `ignore_dims` argument.
    target_dim_set = set(target_dims)
    ignore_dims = [i for i in range(x.shape[-1]) if i not in target_dim_set]
    self.x, self.x_offset, self.x_scale = preprocess_numpy_data(
        x,
        minmax_scaling,
        default_nan_value=default_nan_value,
        per_series_scaling=per_series_scaling,
        ignore_dims=ignore_dims,
    )
    self.lengths = torch.from_numpy(lengths)

    # __getitem__ indexes offset/scale along their first axis, so both are
    # expected to be two-dimensional.
    assert len(self.x_scale.shape) == 2
    assert len(self.x_offset.shape) == 2

  @staticmethod
  def _target_times_mask(target_times, num_timesteps, window):
    """Builds the 0/1 mask of target times that have passed by each timestep.

    Args:
      target_times: tensor of timestamps associated with each target entry
      num_timesteps: total number of timesteps T of the series
      window: number of leading thresholds to drop from the last axis

    Returns:
      Tensor equal to `target_times` repeated along a new last axis of length
      `num_timesteps + 1 - window`. Entry `m` along that axis is 1.0 where the
      target time is less than or equal to `window + m + 1`, and 0.0 otherwise.
    """
    # `repeat` (rather than plain broadcasting) keeps the output rank and shape
    # identical to the historical per-timestep loop implementation.
    expanded = target_times.unsqueeze(-1).repeat(
        1, 1, 1, 1, num_timesteps + 1
    )
    # Slot `idx` of the last axis is compared against the timestep `idx + 1`.
    thresholds = torch.arange(1, num_timesteps + 2, dtype=expanded.dtype)
    mask = (expanded <= thresholds).to(torch.get_default_dtype())
    # starting e.g. timepoint 37, whether time is passed
    return mask[:, :, :, :, window:]

  def _compute_inputs_and_targets(self, x, lengths, default_nan_value):
    """Format x tensor into tensor of inputs and targets.

    Args:
      x: tensor (N x T x D) containing all time series
      lengths: tensor (N) containing the length of each series
      default_nan_value: value used to fill timestamps beyond a series' length

    Returns:
      inputs: tensor, N x (T - input_window_size + 1) x input_window_size x D,
        containing inputs
      targets: tensor, N x (T - input_window_size + 1) x forecasting_horizon
        x D, containing targets
      input_times: tensor of timestamps corresponding to each entry of inputs
      target_times: tensor of timestamps corresponding to each entry of targets
      target_times_mask: binary tensor whose leading axes are those of
        target_times and whose last axis has length
        T + 1 - input_window_size. Entry m along the last axis indicates
        whether the corresponding target time is less than or equal to
        input_window_size + m + 1. This is useful when trying to aggregate
        across all previous forecast errors (i.e. all errors before each
        timepoint).
    """
    N, T, D = x.shape

    window = self.input_window_size
    forecasting_horizon = self.forecasting_horizon

    inputs, targets = compute_windows(x, window, forecasting_horizon)
    assert inputs.shape[0] == N
    assert inputs.shape[1] == T - window + 1
    assert inputs.shape[2] == window
    assert inputs.shape[3] == D

    assert targets.shape[0] == N
    assert targets.shape[1] == T - window + 1
    assert targets.shape[2] == forecasting_horizon
    assert targets.shape[3] == D

    # compute timesteps associated with each unrolled input and target:
    # 1..length for the valid part of each series, default_nan_value after it.
    times = np.full((N, T), default_nan_value, dtype=np.float64)
    for row, length in enumerate(lengths):
      num_valid = int(length)
      times[row, :num_valid] = np.arange(1, num_valid + 1)
    times = torch.from_numpy(times)
    input_times, target_times = compute_windows(
        times, window, forecasting_horizon
    )

    # compute masks associated with each timestep
    target_times_mask = self._target_times_mask(target_times, T, window)
    return inputs, targets, input_times, target_times, target_times_mask

  def __getitem__(self, index):
    """Returns the windows, timestamps and scaling info of one series.

    Args:
      index: integer position of the series; negative values count from the
        end of the dataset.

    Returns:
      Dict with keys 'x', 'x_offset', 'x_scale', 'inputs', 'targets',
      'lengths', 'input_times', 'target_times' and 'target_times_mask'. The
      offset and scale are the entries for this series when more than one is
      stored, and the shared values otherwise.
    """
    # A negative index would turn `index : index + 1` into an empty slice
    # (e.g. -1:0), so map it onto its non-negative equivalent first.
    if index < 0:
      index = index + len(self)

    inputs, targets, input_times, target_times, target_times_mask = (
        self._compute_inputs_and_targets(
            self.x[index : index + 1],
            self.lengths[index : index + 1],
            default_nan_value=self.default_nan_value,
        )
    )

    # The length-1 slices above keep a leading batch axis of size one.
    assert inputs.shape[0] == 1
    assert targets.shape[0] == 1
    assert input_times.shape[0] == 1
    assert target_times.shape[0] == 1

    d = {
        'x': self.x[index],
        'x_offset': self.x_offset,
        'x_scale': self.x_scale,
        'inputs': inputs[0],
        'targets': targets[0],
        'lengths': self.lengths[index],
        'input_times': input_times[0],
        'target_times': target_times[0],
        'target_times_mask': target_times_mask[0],
    }
    # More than one row means offset/scale are stored per series.
    if self.x_offset.shape[0] > 1:
      d['x_offset'] = self.x_offset[index]
      d['x_scale'] = self.x_scale[index]
    return d

  def __len__(self):
    """Returns the number of series in the dataset."""
    return len(self.x)