# coding=utf-8
# Copyright 2024 The Google Research Authors.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

"""Make predictions using sktime lib as baselines, and evaluate."""

# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
# pylint: disable=invalid-name

import argparse
import copy
import datetime
import multiprocessing as mp
import os
import time

from data_formatting.datasets import get_favorita_data
from data_formatting.datasets import get_m3_data
from data_formatting.datasets import get_m3_df
from lib.evaluator import Evaluator
from lib.naive_scaling_baseline import NaiveScalingBaseline
from main import get_full_test_metrics
import numpy as np
import pandas as pd
from sktime.forecasting.arima import ARIMA
from sktime.forecasting.base import ForecastingHorizon
from sktime.forecasting.exp_smoothing import ExponentialSmoothing
from sktime.forecasting.naive import NaiveForecaster
from sktime.forecasting.theta import ThetaForecaster
from sktime.transformations.series.detrend import Deseasonalizer
import torch
from tqdm import tqdm
from utils import get_summary
import wandb


def proc_print(msg):
  print(f'pid: {os.getpid()},\t{msg}')


def tprint(msg):
  now = datetime.datetime.now()
  proc_print(f'[{now}]\t{msg}')


def evaluate(
    dataset_factory,
    test_preds,
    naive_model,
    lead_time,
    scale01,
    test_t_min,
    target_service_level,
    unit_holding_cost,
    unit_stockout_cost,
    unit_var_o_cost,
    valid_t_start=None,
    test_t_start=None,
    target_dims=(0,),
    parallel=False,
    num_workers=0,
    device=torch.device('cpu'),
):
  """Evaluate predictions.

  Args:
    dataset_factory: factory for datasets
    test_preds: predicted values
    naive_model: baseline model
    lead_time: number of time points it takes for inventory to come in
    scale01: whether predictions are scaled between 0 and 1
    test_t_min: first timepoint for evaluation
    target_service_level: service level to use for safety stock calculation
    unit_holding_cost: cost per unit held
    unit_stockout_cost: cost per unit stockout
    unit_var_o_cost: cost per unit order variance
    valid_t_start: start of validation time period
    test_t_start: start of test time period
    target_dims: dimensions corresponding to target
    parallel: whether to evaluate in parallel
    num_workers: number of workers for parallel dataloading
    device: device to perform computations on

  Returns:
    aggregate and per-series metrics
  """
  evaluator = Evaluator(0, scale01, device, target_dims=list(target_dims))

  scale_by_naive_model = False
  quantile_loss = None
  use_wandb = False
  test_metrics, expanded_test_metrics = get_full_test_metrics(
      dataset_factory,
      test_preds.unsqueeze(-1),
      num_workers,
      parallel,
      device,
      test_t_min,
      valid_t_start,
      test_t_start,
      evaluator,
      target_service_level,
      lead_time,
      unit_holding_cost,
      unit_stockout_cost,
      unit_var_o_cost,
      naive_model,
      scale_by_naive_model,
      quantile_loss,
      use_wandb,
  )
  return test_metrics, expanded_test_metrics
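

# Shape note (inferred from predict_roll_forward below): evaluate() expects
# `test_preds` as a torch tensor of shape
# (num series, test_t_max - test_t_min, lead_time); unsqueeze(-1) adds the
# trailing target dimension before handing off to get_full_test_metrics.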


def get_favorita_df(impute0=True, N=None, test_t_max=None):
  """Retrieves the Favorita dataset in a format amenable to the sktime lib.

  Args:
    impute0: whether to impute missing values with 0's
    N: number of series to limit to. If None, no limit
    test_t_max: maximum time point to evaluate

  Returns:
    dataframe with one row per series, an 'N' column holding the series
    length (clipped at test_t_max if provided), and integer columns
    1..test_t_max holding the values.
  """
  df = pd.read_csv('../data/favorita/favorita_df.csv').set_index('traj_id')
  if impute0:
    df = df.fillna(0)

  if N is not None:
    df = df[:N]

  df['N'] = df.shape[1]
  Ns = df['N']
  if test_t_max is not None:
    Ns = df['N'].clip(upper=test_t_max)
  series = df.drop('N', axis=1)
  series.columns = list(range(1, len(series.columns) + 1))

  if test_t_max is not None:
    series = series.iloc[:, :test_t_max]

  df = pd.concat([Ns, series], axis=1)
  print('Length of Favorita dataset: ', len(df))
  return df


class DeseasonalizedForecaster:
  """Wrapper around sktime forecasters that first de-seasonalizes the data."""

  def __init__(self, fc, sp=1, model='additive'):
    self.fc = fc
    self.deseasonalizer = Deseasonalizer(sp=sp, model=model)

  def fit(self, y):
    self.deseasonalizer.fit(y)
    y_new = self.deseasonalizer.transform(y)
    self.fc.fit(y_new)

  def reset(self):
    self.fc.reset()

  def predict(self, horizon):
    preds = self.fc.predict(horizon)
    preds = self.deseasonalizer.inverse_transform(preds)
    return preds

  def __str__(self):
    return 'Deseasonalized' + str(self.fc)
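

# A usage sketch for DeseasonalizedForecaster (illustrative only; `y_train`
# is a hypothetical pd.Series indexed by integer timepoints, and sp=12
# matches the monthly M3 configuration in main() below):
#
#   fc = DeseasonalizedForecaster(ThetaForecaster(deseasonalize=False), sp=12)
#   fc.fit(y_train)
#   preds = fc.predict(
#       ForecastingHorizon(np.array([37, 38, 39]), is_relative=False))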


def predict_series(
    fc, row, series_len, first_cutoff, lead_time, test_t_max, test_t_min, i
):
  """Makes predictions for the series, rolling forward in time.

  Args:
    fc: forecaster
    row: series
    series_len: length of series
    first_cutoff: first time point to make predictions at
    lead_time: time points it takes for inventory to come in
    test_t_max: maximum time point to evaluate
    test_t_min: minimum time point to evaluate
    i: index of series

  Returns:
    predictions for the series (numpy array)
  """
  print('predicting for series: ', i)
  y = row[list(range(1, series_len + 1))]
  y.index = y.index.astype(int)
  series_preds = np.ones((test_t_max - test_t_min, lead_time))
  for cutoff in range(first_cutoff, series_len):
    t = cutoff + 1  # current timepoint
    tr_ids = list(range(1, t))
    te_ids = list(range(t, min(t + lead_time, series_len + 1)))
    horizon = ForecastingHorizon(np.array(te_ids), is_relative=False)
    y_tr = y[y.index.isin(tr_ids)]

    # fit and make predictions
    fc.reset()
    fc.fit(y_tr)

    preds = fc.predict(horizon)
    series_preds[cutoff - first_cutoff, : len(preds)] = preds
  print('finished predicting series: ', i)
  return series_preds
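

# A worked example of the indexing in predict_series (with the M3 settings
# from main(): first_cutoff = test_t_min = 36, lead_time = 6): at the first
# iteration cutoff = 36, so t = 37, the forecaster trains on timepoints 1..36,
# and te_ids = [37, ..., 42] form an absolute ForecastingHorizon. Row 0 of
# series_preds holds that 6-step forecast; slots past the end of a short
# series keep the 1.0 placeholder from np.ones.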


def predict_roll_forward(
    fc,
    fc_name,
    df,
    folder,
    dataset_name,
    start,
    idx_range,
    test_t_max,
    test_t_min,
    lead_time,
    first_cutoff,
    pool=None,
):
  """Make predictions for entire dataframe, rolling forward in time.

  Saves predictions in folder. If previously saved, simply re-loads predictions.

  Args:
    fc: forecaster
    fc_name: forecaster name
    df: dataframe containing (univariate) data to make predictions for
    folder: folder to save predictions
    dataset_name: name of dataset
    start: starting unix time (to calculate runtime)
    idx_range: range of indices of series to consider
    test_t_max: maximum timepoint to evaluate
    test_t_min: minimum timepoint to evaluate
    lead_time: number of timepoints it takes for inventory to come in
    first_cutoff: first time point to make predictions for
    pool: multiprocessing pool, if available (else None)

  Returns:
    predictions for entire dataframe, across all time
  """
  fpath = os.path.join(folder, f'{dataset_name}_{fc_name}_N{len(df)}.npy')
  print(fpath)
  if os.path.exists(fpath):
    test_preds = np.load(fpath)
    print('loaded predictions from: ', fpath)
  else:
    print(fc_name, ' time: ', time.time() - start)
    print(idx_range)
    all_series_preds = []
    for i in tqdm(range(idx_range[0], idx_range[1])):
      row = df.iloc[i, :]
      series_len = int(row['N'])
      if pool is not None:
        all_series_preds.append((
            i,
            pool.apply_async(
                predict_series,
                [
                    copy.deepcopy(fc),
                    row,
                    series_len,
                    first_cutoff,
                    lead_time,
                    test_t_max,
                    test_t_min,
                    i,
                ],
            ),
        ))
      else:
        all_series_preds.append((
            i,
            predict_series(
                fc,
                row,
                series_len,
                first_cutoff,
                lead_time,
                test_t_max,
                test_t_min,
                i,
            ),
        ))

    if pool is not None:
      for i, series_preds in all_series_preds:
        while not series_preds.ready():
          print('waiting for series: ', i)
          time.sleep(100)

    test_preds = (
        np.ones(
            (idx_range[1] - idx_range[0], test_t_max - test_t_min, lead_time)
        )
        * 1e18
    )
    for i, series_preds in all_series_preds:
      if pool is not None:
        series_preds = series_preds.get()
      test_preds[i - idx_range[0], :] = series_preds
    np.save(fpath, test_preds, allow_pickle=False)
  print('finished predicting roll-forward: ', time.time() - start)
  return test_preds
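

# Contract note for predict_roll_forward (a sketch): the returned array has
# shape (idx_range[1] - idx_range[0], test_t_max - test_t_min, lead_time),
# initialized to the 1e18 sentinel. Cached files are keyed by dataset,
# forecaster, and dataset size, e.g. 'm3_ARIMA_N334.npy' under `folder`
# (the series count 334 here is hypothetical).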


def get_sktime_predictions(
    df,
    dataset_name,
    forecasters,
    idx_range,
    test_t_max,
    test_t_min,
    lead_time,
    first_cutoff,
    folder,
    parallel=False,
):
  """Get predictions for the forecasters of interest.

  Args:
    df: dataset
    dataset_name: name of dataset
    forecasters: dictionary mapping forecaster names to forecasters
    idx_range: range of indices to evaluate
    test_t_max: maximum timepoint to evaluate
    test_t_min: minimum timepoint to evaluate
    lead_time: number of timepoints it takes for inventory to come in
    first_cutoff: first timepoint to make predictions for
    folder: folder to save predictions
    parallel: whether to make predictions in parallel

  Returns:
    dictionary mapping forecaster names to predictions
  """
  start = time.time()
  pool = None
  if parallel:
    num_proc = int(mp.cpu_count())
    pool = mp.Pool(num_proc)
    print('Number of processors: ', num_proc)

  fc_to_preds = {}
  for fc_name, fc in forecasters.items():
    test_preds = predict_roll_forward(
        fc,
        fc_name,
        df,
        folder,
        dataset_name,
        start,
        idx_range,
        test_t_max,
        test_t_min,
        lead_time,
        first_cutoff,
        pool,
    )
    fc_to_preds[fc_name] = test_preds

  if parallel:
    pool.close()
    pool.join()
  return fc_to_preds
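

# Example call (a sketch mirroring the M3 branch of main() below):
#
#   fc_to_preds = get_sktime_predictions(
#       df, 'm3', {'ARIMA': ARIMA()}, idx_range=(20, len(df)),
#       test_t_max=144, test_t_min=36, lead_time=6, first_cutoff=36,
#       folder='sktime_predictions_seasonal/')
#
# Each value in fc_to_preds is a numpy array of per-cutoff forecasts with
# shape (num series, test_t_max - test_t_min, lead_time).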


def main():
  mp.set_start_method('spawn')
  parser = argparse.ArgumentParser()
  parser.add_argument('--parallel', action='store_true')
  parser.add_argument('--dataset', choices=['m3', 'favorita'])
  parser.add_argument(
      '--forecasters',
      choices=[
          'NaiveForecaster',
          'ExponentialSmoothing',
          'ThetaForecaster',
          'ARIMA',
          'DeseasonalizedThetaForecaster',
      ],
      action='append',
  )
  parser.add_argument('--N', type=int, default=None)
  parser.add_argument('--preds_only', action='store_true')
  parser.add_argument('--num_workers', type=int, default=0)

  args = parser.parse_args()

  wandb_log = True
  parallel = args.parallel
  num_workers = args.num_workers
  dataset_name = args.dataset

  if dataset_name == 'm3':
    data_fpath = '../data/m3/m3_industry_monthly_shuffled.csv'
    df = get_m3_df(N=args.N, csv_fpath=data_fpath, idx_range=None)
    Ns = df['N']
    series = df.drop('N', axis=1)
    series.columns = series.columns.astype(int)
    df = pd.concat([Ns, series], axis=1)

    idx_range = (20, len(df))

    test_t_min = 36
    test_t_max = 144
    valid_t_start = 72
    test_t_start = 108

    forecasting_horizon = 12
    input_window_size = 24

    lead_time = 6
    scale01 = False
    target_service_level = 0.95
    N = args.N
    periodicity = 12
    first_cutoff = test_t_min

    dataset_factory = get_m3_data(
        forecasting_horizon=forecasting_horizon,
        minmax_scaling=scale01,
        train_prop=None,
        val_prop=None,
        batch_size=None,
        input_window_size=input_window_size,
        csv_fpath=data_fpath,
        default_nan_value=1e15,
        rolling_evaluation=True,
        idx_range=idx_range,
        N=N,
    )
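
    # Each unit-cost tuple below is
    # (unit_holding_cost, unit_stockout_cost, unit_var_o_cost); the grid is
    # unpacked in the evaluation loop at the bottom of main().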
    unit_costs = [
        (1, 1, 1e-06),
        (1, 1, 1e-05),
        (1, 2, 1e-06),
        (1, 2, 1e-05),
        (1, 10, 1e-06),
        (1, 10, 1e-05),
        (2, 1, 1e-06),
        (2, 1, 1e-05),
        (2, 2, 1e-06),
        (2, 2, 1e-05),
        (2, 10, 1e-06),
        (2, 10, 1e-05),
        (10, 1, 1e-06),
        (10, 1, 1e-05),
        (10, 2, 1e-06),
        (10, 2, 1e-05),
        (10, 10, 1e-06),
        (10, 10, 1e-05),
    ]
  else:
    assert dataset_name == 'favorita'
    test_t_max = 396
    valid_t_start = 334
    test_t_start = 364
    test_t_min = 180
    forecasting_horizon = 30
    input_window_size = 90
    first_cutoff = test_t_min

    df = get_favorita_df(impute0=True, N=args.N, test_t_max=test_t_max)
    idx_range = (0, len(df))

    lead_time = 7
    scale01 = False
    N = args.N
    target_service_level = 0.95
    periodicity = 7
    dataset_factory = get_favorita_data(
        forecasting_horizon=forecasting_horizon,
        minmax_scaling=scale01,
        input_window_size=input_window_size,
        data_fpath='../data/favorita/favorita_tensor_full.npy',
        default_nan_value=1e15,
        rolling_evaluation=True,
        N=N,
        test_t_max=test_t_max,
    )

    unit_costs = [
        (1, 1, 1e-02),
        (1, 1, 1e-03),
        (1, 2, 1e-02),
        (1, 2, 1e-03),
        (1, 10, 1e-02),
        (1, 10, 1e-03),
        (2, 1, 1e-02),
        (2, 1, 1e-03),
        (2, 2, 1e-02),
        (2, 2, 1e-03),
        (2, 10, 1e-02),
        (2, 10, 1e-03),
        (10, 1, 1e-02),
        (10, 1, 1e-03),
        (10, 2, 1e-02),
        (10, 2, 1e-03),
        (10, 10, 1e-02),
        (10, 10, 1e-03),
    ]

  all_forecasters = {
      'NaiveForecaster': NaiveForecaster(sp=periodicity),
      'ExponentialSmoothing': ExponentialSmoothing(
          trend='add', seasonal='add', sp=periodicity
      ),
      # 'ThetaForecaster' is an allowed --forecasters choice, so give it an
      # entry here; otherwise the lookup below raises a KeyError.
      'ThetaForecaster': ThetaForecaster(sp=periodicity),
      'DeseasonalizedThetaForecaster': DeseasonalizedForecaster(
          ThetaForecaster(deseasonalize=False), sp=periodicity
      ),
      'ARIMA': ARIMA(),
  }

  forecasters = {k: all_forecasters[k] for k in args.forecasters}
  folder = 'sktime_predictions_seasonal/'
  if not os.path.exists(folder):
    os.makedirs(folder)
  fc_to_preds = get_sktime_predictions(
      df,
      dataset_name,
      forecasters,
      idx_range,
      test_t_max,
      test_t_min,
      lead_time,
      first_cutoff,
      folder,
      parallel=parallel,
  )

  if args.preds_only:
    return

  tprint('Making evaluations...')
  start = time.time()
  naive_model = NaiveScalingBaseline(
      forecasting_horizon=lead_time, init_alpha=1.0, periodicity=12, frozen=True
  )
  for fc_name, test_preds in fc_to_preds.items():
    test_preds = torch.from_numpy(test_preds)
    for unit_cost in unit_costs:
      print(unit_cost)
      unit_holding_cost, unit_stockout_cost, unit_var_o_cost = unit_cost
      test_metrics, expanded_test_metrics = evaluate(
          dataset_factory,
          test_preds,
          naive_model,
          lead_time,
          scale01,
          test_t_min,
          target_service_level,
          unit_holding_cost,
          unit_stockout_cost,
          unit_var_o_cost,
          valid_t_start=valid_t_start,
          test_t_start=test_t_start,
          parallel=parallel,
          num_workers=num_workers,
      )
      print('getting summary...')
      test_results = get_summary(
          test_metrics=test_metrics,
          model_name=fc_name,
          optimization_obj='None',
          max_steps='None',
          start=start,
          unit_holding_cost=unit_holding_cost,
          unit_stockout_cost=unit_stockout_cost,
          unit_var_o_cost=unit_var_o_cost,
          valid_t_start=valid_t_start,
          learned_alpha=None,
          quantile_loss=None,
          naive_model=naive_model,
          use_wandb=False,
          expanded_test_metrics=expanded_test_metrics,
          idx_range=None,
      )

      summary = test_results['summary']
      now = datetime.datetime.now()
      now = now.strftime('%m-%d-%Y-%H:%M:%S')
      tags = ['sktime']
      tag_str = ''.join(tags)
      if wandb_log:
        wandb.init(
            name=f'{tag_str}_{now}_{fc_name}_summary',
            project='sktime-seasonal-summaries',
            reinit=True,
            tags=tags,
            config={
                'model_name': fc_name,
                'unit_holding_cost': unit_holding_cost,
                'unit_stockout_cost': unit_stockout_cost,
                'unit_var_o_cost': unit_var_o_cost,
                'dataset_name': args.dataset,
            },
        )
        wandb.log(
            {
                'combined_test_perfs': wandb.Table(
                    dataframe=pd.DataFrame([summary])
                )
            }
        )
        wandb.finish()


if __name__ == '__main__':
  main()