# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

16"""Make predictions using sktime lib as baselines, and evaluate."""
17
18# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
19# pylint: disable=invalid-name
20
21import argparse
22import copy
23import datetime
24import multiprocessing as mp
25import os
26import time
27
28from data_formatting.datasets import get_favorita_data
29from data_formatting.datasets import get_m3_data
30from data_formatting.datasets import get_m3_df
31from lib.evaluator import Evaluator
32from lib.naive_scaling_baseline import NaiveScalingBaseline
33from main import get_full_test_metrics
34import numpy as np
35import pandas as pd
36from sktime.forecasting.arima import ARIMA
37from sktime.forecasting.base import ForecastingHorizon
38from sktime.forecasting.exp_smoothing import ExponentialSmoothing
39from sktime.forecasting.naive import NaiveForecaster
40from sktime.forecasting.theta import ThetaForecaster
41from sktime.transformations.series.detrend import Deseasonalizer
42import torch
43from tqdm import tqdm
44from utils import get_summary
45import wandb
46
47
def proc_print(msg):
  print(f'pid: {os.getpid()},\t{msg}')


def tprint(msg):
  now = datetime.datetime.now()
  proc_print(f'[{now}]\t{msg}')


def evaluate(
    dataset_factory,
    test_preds,
    naive_model,
    lead_time,
    scale01,
    test_t_min,
    target_service_level,
    unit_holding_cost,
    unit_stockout_cost,
    unit_var_o_cost,
    valid_t_start=None,
    test_t_start=None,
    target_dims=(0,),
    parallel=False,
    num_workers=0,
    device=torch.device('cpu'),
):
  """Evaluate predictions.

  Args:
    dataset_factory: factory for datasets
    test_preds: predicted values
    naive_model: baseline model
    lead_time: number of time points it takes for inventory to come in
    scale01: whether predictions are scaled between 0 and 1
    test_t_min: first timepoint for evaluation
    target_service_level: service level to use for safety stock calculation
    unit_holding_cost: cost per unit held
    unit_stockout_cost: cost per unit stockout
    unit_var_o_cost: cost per unit order variance
    valid_t_start: start of validation time period
    test_t_start: start of test time period
    target_dims: dimensions corresponding to target
    parallel: whether to evaluate in parallel
    num_workers: number of workers for parallel dataloading
    device: device to perform computations on

  Returns:
    aggregate and per-series metrics
  """
  evaluator = Evaluator(0, scale01, device, target_dims=list(target_dims))

  scale_by_naive_model = False
  quantile_loss = None
  use_wandb = False
  test_metrics, expanded_test_metrics = get_full_test_metrics(
      dataset_factory,
      test_preds.unsqueeze(-1),
      num_workers,
      parallel,
      device,
      test_t_min,
      valid_t_start,
      test_t_start,
      evaluator,
      target_service_level,
      lead_time,
      unit_holding_cost,
      unit_stockout_cost,
      unit_var_o_cost,
      naive_model,
      scale_by_naive_model,
      quantile_loss,
      use_wandb,
  )
  return test_metrics, expanded_test_metrics


def get_favorita_df(impute0=True, N=None, test_t_max=None):
  """Retrieves the Favorita dataset in a format amenable to the sktime lib.

  Args:
    impute0: whether to impute missing values with 0's
    N: number of series to limit to. if None, no limit
    test_t_max: maximum time point to evaluate

  Returns:
    dataframe with one series per row, where column 'N' holds the (possibly
    clipped) series length and the remaining columns, labeled 1..T, hold the
    observations at each timepoint.
  """
  df = pd.read_csv('../data/favorita/favorita_df.csv').set_index('traj_id')
  if impute0:
    df = df.fillna(0)

  if N is not None:
    df = df[:N]

  df['N'] = df.shape[1]
  Ns = df['N']
  if test_t_max is not None:
    Ns = df['N'].clip(upper=test_t_max)
  series = df.drop('N', axis=1)
  series.columns = list(range(1, len(series.columns) + 1))

  if test_t_max is not None:
    series = series.iloc[:, :test_t_max]

  df = pd.concat([Ns, series], axis=1)
  print('Length of Favorita dataset: ', len(df))
  return df


class DeseasonalizedForecaster:
  """Wrapper around sktime forecasters that first de-seasonalizes the data."""

  def __init__(self, fc, sp=1, model='additive'):
    self.fc = fc
    self.deseasonalizer = Deseasonalizer(sp=sp, model=model)

  def fit(self, y):
    self.deseasonalizer.fit(y)
    y_new = self.deseasonalizer.transform(y)
    self.fc.fit(y_new)

  def reset(self):
    self.fc.reset()

  def predict(self, horizon):
    preds = self.fc.predict(horizon)
    preds = self.deseasonalizer.inverse_transform(preds)
    return preds

  def __str__(self):
    return 'Deseasonalized' + str(self.fc)


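# A minimal usage sketch for DeseasonalizedForecaster (illustrative only; `y`
# is a hypothetical pd.Series of monthly values indexed by integer
# timepoints 1..T):
#
#   fc = DeseasonalizedForecaster(ThetaForecaster(deseasonalize=False), sp=12)
#   fc.fit(y)
#   horizon = ForecastingHorizon(np.arange(T + 1, T + 7), is_relative=False)
#   preds = fc.predict(horizon)

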
def predict_series(
    fc, row, series_len, first_cutoff, lead_time, test_t_max, test_t_min, i
):
  """Rolling forward in time, make predictions for the series.

  Args:
    fc: forecaster
    row: series
    series_len: length of series
    first_cutoff: first time point to make predictions at
    lead_time: number of time points it takes for inventory to come in
    test_t_max: maximum time point to evaluate
    test_t_min: minimum time point to evaluate
    i: index of series

  Returns:
    predictions for the series (numpy array)
  """
  print('predicting for series: ', i)
  y = row[list(range(1, series_len + 1))]
  y.index = y.index.astype(int)
  series_preds = np.ones((test_t_max - test_t_min, lead_time))
  for cutoff in range(first_cutoff, series_len):
    t = cutoff + 1  # current timepoint
    tr_ids = list(range(1, t))
    te_ids = list(range(t, min(t + lead_time, series_len + 1)))
    horizon = ForecastingHorizon(np.array(te_ids), is_relative=False)
    y_tr = y[y.index.isin(tr_ids)]

    # fit and make predictions
    fc.reset()
    fc.fit(y_tr)

    preds = fc.predict(horizon)
    series_preds[cutoff - first_cutoff, : len(preds)] = preds
  print('finished predicting series: ', i)
  return series_preds


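# Worked example of the rolling-origin scheme in predict_series, assuming
# first_cutoff=36 and lead_time=6 (the m3 settings used in main()):
#   cutoff 36 -> fit on timepoints 1..36, predict 37..42
#   cutoff 37 -> fit on timepoints 1..37, predict 38..43
#   ... and so on, one row of series_preds per cutoff.

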
def predict_roll_forward(
    fc,
    fc_name,
    df,
    folder,
    dataset_name,
    start,
    idx_range,
    test_t_max,
    test_t_min,
    lead_time,
    first_cutoff,
    pool=None,
):
  """Make predictions for the entire dataframe, rolling forward in time.

  Saves predictions in folder. If previously saved, simply re-loads predictions.

  Args:
    fc: forecaster
    fc_name: forecaster name
    df: dataframe containing (univariate) data to make predictions for
    folder: folder to save predictions
    dataset_name: name of dataset
    start: starting unix time (to calculate runtime)
    idx_range: range of indices of series to consider
    test_t_max: maximum timepoint to evaluate
    test_t_min: minimum timepoint to evaluate
    lead_time: number of timepoints it takes for inventory to come in
    first_cutoff: first time point to make predictions for
    pool: multiprocessing pool, if available (else None)

  Returns:
    predictions for entire dataframe, across all time
  """
  fpath = os.path.join(folder, f'{dataset_name}_{fc_name}_N{len(df)}.npy')
  print(fpath)
  if os.path.exists(fpath):
    test_preds = np.load(fpath)
    print('loaded predictions from: ', fpath)
  else:
    print(fc_name, ' time: ', time.time() - start)
    print(idx_range)
    all_series_preds = []
    for i in tqdm(range(idx_range[0], idx_range[1])):
      row = df.iloc[i, :]
      series_len = int(row['N'])
      if pool is not None:
        all_series_preds.append((
            i,
            pool.apply_async(
                predict_series,
                [
                    copy.deepcopy(fc),
                    row,
                    series_len,
                    first_cutoff,
                    lead_time,
                    test_t_max,
                    test_t_min,
                    i,
                ],
            ),
        ))
      else:
        all_series_preds.append((
            i,
            predict_series(
                fc,
                row,
                series_len,
                first_cutoff,
                lead_time,
                test_t_max,
                test_t_min,
                i,
            ),
        ))

    if pool is not None:
      for i, series_preds in all_series_preds:
        while not series_preds.ready():
          print('waiting for series: ', i)
          time.sleep(100)

    test_preds = (
        np.ones(
            (idx_range[1] - idx_range[0], test_t_max - test_t_min, lead_time)
        )
        * 1e18
    )
    for i, series_preds in all_series_preds:
      if pool is not None:
        series_preds = series_preds.get()
      test_preds[i - idx_range[0], :] = series_preds
    np.save(fpath, test_preds, allow_pickle=False)
  print('finished predicting roll-forward: ', time.time() - start)
  return test_preds


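# Note: predict_roll_forward caches each forecaster's predictions at
# {folder}/{dataset_name}_{fc_name}_N{len(df)}.npy, an array of shape
# (num_series, test_t_max - test_t_min, lead_time); delete the cached file to
# force recomputation.

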
def get_sktime_predictions(
    df,
    dataset_name,
    forecasters,
    idx_range,
    test_t_max,
    test_t_min,
    lead_time,
    first_cutoff,
    folder,
    parallel=False,
):
  """Get predictions for the forecasters of interest.

  Args:
    df: dataset
    dataset_name: name of dataset
    forecasters: dict mapping names to forecasters of interest
    idx_range: range of indices to evaluate
    test_t_max: maximum timepoint to evaluate
    test_t_min: minimum timepoint to evaluate
    lead_time: number of timepoints it takes for inventory to come in
    first_cutoff: first timepoint to make predictions for
    folder: folder to save predictions
    parallel: whether to make predictions in parallel

  Returns:
    dictionary mapping forecaster names to predictions
  """
  start = time.time()
  pool = None
  if parallel:
    num_proc = int(mp.cpu_count())
    pool = mp.Pool(num_proc)
    print('Number of processors: ', num_proc)

  fc_to_preds = {}
  for fc_name, fc in forecasters.items():
    test_preds = predict_roll_forward(
        fc,
        fc_name,
        df,
        folder,
        dataset_name,
        start,
        idx_range,
        test_t_max,
        test_t_min,
        lead_time,
        first_cutoff,
        pool,
    )
    fc_to_preds[fc_name] = test_preds

  if parallel:
    pool.close()
    pool.join()
  return fc_to_preds


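# Example invocation (hypothetical script name; the data files referenced in
# main() are assumed to live under ../data/):
#
#   python sktime_baselines.py --dataset m3 --N 100 \
#       --forecasters NaiveForecaster --forecasters ARIMA --parallel

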
def main():
  mp.set_start_method('spawn')
  parser = argparse.ArgumentParser()
  parser.add_argument('--parallel', action='store_true')
  parser.add_argument('--dataset', choices=['m3', 'favorita'])
  parser.add_argument(
      '--forecasters',
      choices=[
          'NaiveForecaster',
          'ExponentialSmoothing',
          'ThetaForecaster',
          'ARIMA',
          'DeseasonalizedThetaForecaster',
      ],
      action='append',
  )
  parser.add_argument('--N', type=int, default=None)
  parser.add_argument('--preds_only', action='store_true')
  parser.add_argument('--num_workers', type=int, default=0)

  args = parser.parse_args()

  wandb_log = True
  parallel = args.parallel
  num_workers = args.num_workers
  dataset_name = args.dataset

  if dataset_name == 'm3':
    data_fpath = '../data/m3/m3_industry_monthly_shuffled.csv'
    df = get_m3_df(N=args.N, csv_fpath=data_fpath, idx_range=None)
    Ns = df['N']
    series = df.drop('N', axis=1)
    series.columns = series.columns.astype(int)
    df = pd.concat([Ns, series], axis=1)

    idx_range = (20, len(df))

    test_t_min = 36
    test_t_max = 144
    valid_t_start = 72
    test_t_start = 108

    forecasting_horizon = 12
    input_window_size = 24

    lead_time = 6
    scale01 = False
    target_service_level = 0.95
    N = args.N
    periodicity = 12
    first_cutoff = test_t_min

    dataset_factory = get_m3_data(
        forecasting_horizon=forecasting_horizon,
        minmax_scaling=scale01,
        train_prop=None,
        val_prop=None,
        batch_size=None,
        input_window_size=input_window_size,
        csv_fpath=data_fpath,
        default_nan_value=1e15,
        rolling_evaluation=True,
        idx_range=idx_range,
        N=N,
    )

    # (unit_holding_cost, unit_stockout_cost, unit_var_o_cost) combinations
    unit_costs = [
        (1, 1, 1e-06),
        (1, 1, 1e-05),
        (1, 2, 1e-06),
        (1, 2, 1e-05),
        (1, 10, 1e-06),
        (1, 10, 1e-05),
        (2, 1, 1e-06),
        (2, 1, 1e-05),
        (2, 2, 1e-06),
        (2, 2, 1e-05),
        (2, 10, 1e-06),
        (2, 10, 1e-05),
        (10, 1, 1e-06),
        (10, 1, 1e-05),
        (10, 2, 1e-06),
        (10, 2, 1e-05),
        (10, 10, 1e-06),
        (10, 10, 1e-05),
    ]
  else:
    assert dataset_name == 'favorita'
    test_t_max = 396
    valid_t_start = 334
    test_t_start = 364
    test_t_min = 180
    forecasting_horizon = 30
    input_window_size = 90
    first_cutoff = test_t_min

    df = get_favorita_df(impute0=True, N=args.N, test_t_max=test_t_max)
    idx_range = (0, len(df))

    lead_time = 7
    scale01 = False
    N = args.N
    target_service_level = 0.95
    periodicity = 7
    dataset_factory = get_favorita_data(
        forecasting_horizon=forecasting_horizon,
        minmax_scaling=scale01,
        input_window_size=input_window_size,
        data_fpath='../data/favorita/favorita_tensor_full.npy',
        default_nan_value=1e15,
        rolling_evaluation=True,
        N=N,
        test_t_max=test_t_max,
    )

    unit_costs = [
        (1, 1, 1e-02),
        (1, 1, 1e-03),
        (1, 2, 1e-02),
        (1, 2, 1e-03),
        (1, 10, 1e-02),
        (1, 10, 1e-03),
        (2, 1, 1e-02),
        (2, 1, 1e-03),
        (2, 2, 1e-02),
        (2, 2, 1e-03),
        (2, 10, 1e-02),
        (2, 10, 1e-03),
        (10, 1, 1e-02),
        (10, 1, 1e-03),
        (10, 2, 1e-02),
        (10, 2, 1e-03),
        (10, 10, 1e-02),
        (10, 10, 1e-03),
    ]

  all_forecasters = {
      'NaiveForecaster': NaiveForecaster(sp=periodicity),
      'ExponentialSmoothing': ExponentialSmoothing(
          trend='add', seasonal='add', sp=periodicity
      ),
      # 'ThetaForecaster' is an accepted --forecasters choice, so it needs an
      # entry here to avoid a KeyError below.
      'ThetaForecaster': ThetaForecaster(sp=periodicity),
      'DeseasonalizedThetaForecaster': DeseasonalizedForecaster(
          ThetaForecaster(deseasonalize=False), sp=periodicity
      ),
      'ARIMA': ARIMA(),
  }

  forecasters = {k: all_forecasters[k] for k in args.forecasters}
  folder = 'sktime_predictions_seasonal/'
  if not os.path.exists(folder):
    os.makedirs(folder)
  fc_to_preds = get_sktime_predictions(
      df,
      dataset_name,
      forecasters,
      idx_range,
      test_t_max,
      test_t_min,
      lead_time,
      first_cutoff,
      folder,
      parallel=parallel,
  )

  if args.preds_only:
    return

  tprint('Making evaluations...')
  start = time.time()
  naive_model = NaiveScalingBaseline(
      forecasting_horizon=lead_time, init_alpha=1.0, periodicity=12, frozen=True
  )
  for fc_name, test_preds in fc_to_preds.items():
    test_preds = torch.from_numpy(test_preds)
    for unit_cost in unit_costs:
      print(unit_cost)
      unit_holding_cost, unit_stockout_cost, unit_var_o_cost = unit_cost
      test_metrics, expanded_test_metrics = evaluate(
          dataset_factory,
          test_preds,
          naive_model,
          lead_time,
          scale01,
          test_t_min,
          target_service_level,
          unit_holding_cost,
          unit_stockout_cost,
          unit_var_o_cost,
          valid_t_start=valid_t_start,
          test_t_start=test_t_start,
          parallel=parallel,
          num_workers=num_workers,
      )
      print('getting summary...')
      test_results = get_summary(
          test_metrics=test_metrics,
          model_name=fc_name,
          optimization_obj='None',
          max_steps='None',
          start=start,
          unit_holding_cost=unit_holding_cost,
          unit_stockout_cost=unit_stockout_cost,
          unit_var_o_cost=unit_var_o_cost,
          valid_t_start=valid_t_start,
          learned_alpha=None,
          quantile_loss=None,
          naive_model=naive_model,
          use_wandb=False,
          expanded_test_metrics=expanded_test_metrics,
          idx_range=None,
      )

      summary = test_results['summary']
      now = datetime.datetime.now()
      now = now.strftime('%m-%d-%Y-%H:%M:%S')
      tags = ['sktime']
      tag_str = ''.join(tags)
      if wandb_log:
        wandb.init(
            name=f'{tag_str}_{now}_{fc_name}_summary',
            project='sktime-seasonal-summaries',
            reinit=True,
            tags=tags,
            config={
                'model_name': fc_name,
                'unit_holding_cost': unit_holding_cost,
                'unit_stockout_cost': unit_stockout_cost,
                'unit_var_o_cost': unit_var_o_cost,
                'dataset_name': args.dataset,
            },
        )
        wandb.log(
            {
                'combined_test_perfs': wandb.Table(
                    dataframe=pd.DataFrame([summary])
                )
            }
        )
        wandb.finish()


if __name__ == '__main__':
  main()