google-research
85 строк · 2.6 Кб
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16# coding=utf-8
17# Copyright 2022 The Google Research Authors.
18#
19# Licensed under the Apache License, Version 2.0 (the "License");
20# you may not use this file except in compliance with the License.
21# You may obtain a copy of the License at
22#
23# http://www.apache.org/licenses/LICENSE-2.0
24#
25# Unless required by applicable law or agreed to in writing, software
26# distributed under the License is distributed on an "AS IS" BASIS,
27# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
28# See the License for the specific language governing permissions and
29# limitations under the License.
30"""Generates Toy Data."""
31
import argparse
import pathlib

import numpy as np
35
36
def generate_sine_series(length):
  """Generates an example time series that can be used with the model.

  The toy graph has six nodes. Nodes 0-2 carry the base waves
  sin(t), sin(2t + 1) and sin(3t + 2). Nodes 3-5 carry the pairwise sums
  (0+1, 0+2, 1+2) of those waves. The adjacency matrix is symmetric, has
  self-loops, and links each sum node to the two base nodes it is built from.

  Args:
    length: The length of the time series, i.e. the number of time steps.
      Must be non-negative.

  Returns:
    A tuple ``(adjacency_matrix, series_array)``:
      adjacency_matrix: float array of shape ``(6, 6)``.
      series_array: float array of shape ``(length, 6, 1)``, indexed as
        (time step, node, trailing singleton axis).

  Raises:
    ValueError: If ``length`` is negative.
  """
  # Fail loudly instead of silently producing empty or ill-shaped data.
  if length < 0:
    raise ValueError(f"length must be non-negative, got {length}.")

  adjacency_matrix = np.array([
      [1.0, 0.0, 0.0, 1.0, 1.0, 0.0],
      [0.0, 1.0, 0.0, 1.0, 0.0, 1.0],
      [0.0, 0.0, 1.0, 0.0, 1.0, 1.0],
      [1.0, 1.0, 0.0, 1.0, 0.0, 0.0],
      [1.0, 0.0, 1.0, 0.0, 1.0, 0.0],
      [0.0, 1.0, 1.0, 0.0, 0.0, 1.0],
  ])
  time_steps = np.arange(length)
  sine_series1 = np.sin(time_steps)
  sine_series2 = np.sin(time_steps * 2 + 1.0)
  sine_series3 = np.sin(time_steps * 3 + 2.0)
  # Column j of series_array is the signal of node j (row/col j of adjacency).
  series_array = np.stack(
      (
          sine_series1,
          sine_series2,
          sine_series3,
          sine_series1 + sine_series2,
          sine_series1 + sine_series3,
          sine_series2 + sine_series3,
      ),
      axis=1,
  )

  # Add the trailing singleton axis -> (length, num_nodes, 1). Deriving it
  # from the data avoids hard-coding the node count a second time.
  return adjacency_matrix, series_array[..., np.newaxis]
69
70
if __name__ == "__main__":
  parser = argparse.ArgumentParser()
  parser.add_argument(
      "-l",
      "--length",
      type=int,
      default=10000,
      help="Time series length (number of time steps).")
  parser.add_argument(
      "-p",
      "--path",
      type=str,
      default="./editable_graph_temporal/toy_data.npz",
      help="The path to save the data file.")

  args = parser.parse_args()
  # Reject bad input at the CLI boundary with a usage message.
  if args.length < 0:
    parser.error(f"--length must be non-negative, got {args.length}")

  # np.savez does not create missing parent directories; make sure the
  # target directory (e.g. the default ./editable_graph_temporal/) exists.
  pathlib.Path(args.path).parent.mkdir(parents=True, exist_ok=True)

  adj, time_data = generate_sine_series(args.length)
  np.savez(args.path, x=time_data, adj=adj)
86