google-research

Форк
0
371 строка · 11.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Given a set of predictions, computes evaluation metrics.
17

18
After the model has been trained using main.py and its predictions have been saved,
19
this script can be used to further evaluate the predictions under various
20
configurations (e.g. cost tradeoffs).
21
"""
22

23
# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
24
# pylint: disable=invalid-name
25

26
import argparse
27
import itertools
28
import multiprocessing
29
import os
30
import pickle
31
import pprint
32
import time
33
from data_formatting.datasets import get_favorita_data
34
from data_formatting.datasets import get_m3_data
35
from lib.evaluator import Evaluator
36
from lib.naive_scaling_baseline import NaiveScalingBaseline
37
from main import get_full_test_metrics
38
from main import get_learned_alpha
39
import numpy as np
40
import pandas as pd
41
import torch
42
from utils.log_utils import get_summary
43
from utils.log_utils import tprint
44
import wandb
45

46

47
def _load_test_preds(preds_path, device):
  """Load saved test predictions from disk and move them to `device`.

  The format is chosen by the file suffix:
    * `.pkl`: a pickled tensor, or a list of tensors concatenated along dim 1.
    * `.npy`: a numpy array; a 3-D array gets a trailing singleton dimension
      appended so the result is 4-D.
    * paths ending in `test_preds.pt`: an object saved with `torch.save`.

  Args:
    preds_path: path to the saved predictions
    device: torch.device the predictions are moved to

  Returns:
    The predictions as a torch tensor on `device`.

  Raises:
    NotImplementedError: if `preds_path` has an unsupported file type
  """
  if preds_path.endswith('.pkl'):
    # NOTE(review): pickle.load can execute arbitrary code embedded in the
    # file; only point this at prediction files from a trusted source.
    with open(preds_path, 'rb') as fin:
      preds = pickle.load(fin)
    if isinstance(preds, list):
      preds = torch.cat(preds, dim=1)
  elif preds_path.endswith('.npy'):
    preds = torch.from_numpy(np.load(preds_path))
    if len(preds.shape) == 3:
      preds = preds.unsqueeze(-1)
  elif preds_path.endswith('test_preds.pt'):
    # map_location lets GPU-saved predictions be loaded on a CPU-only host.
    preds = torch.load(preds_path, map_location=device)
  else:
    raise NotImplementedError('Unrecognized file type: ' + preds_path)
  return preds.to(device)


def eval_preds(
    wandb_log,
    parallel,
    num_workers,
    dataset_name,
    model_name,
    optimization_obj,
    single_rollout,
    no_safety_stock,
    max_steps,
    N,
    preds_path,
    model_path,
    project_name,
    run_name,
    tags,
    just_convert_to_cpu,
    device_name,
    cpu_checkpt_folder,
    unit_holding_costs,
    unit_stockout_costs,
    unit_var_o_costs,
    do_scaling,
):
  """Evaluate saved model predictions under a grid of inventory cost settings.

  Loads predictions from `preds_path` and truncates them to the dataset's
  lead time. Then, for every (holding, stockout, order-variance) unit-cost
  combination, computes the full test metrics via `get_full_test_metrics` and
  summarizes them via `get_summary`. When `wandb_log` is set, each combination
  is logged as its own wandb run.

  With `just_convert_to_cpu`, the function instead saves a CPU copy of the
  (truncated) predictions, plus the learned alpha of a naive model if one is
  given, to `<cpu_checkpt_folder>/cpu_checkpoint.pt` and returns.

  Args:
    wandb_log: whether to log the metrics to wandb
    parallel: whether to evaluate in parallel
    num_workers: number of workers for parallel dataloading
    dataset_name: name of dataset ('m3' or 'favorita')
    model_name: name of model
    optimization_obj: name of optimization obj
    single_rollout: whether single rollout (vs. double rollout)
    no_safety_stock: safety-stock setting for the order-up-to policy, passed
      through to `Evaluator`
    max_steps: num steps per timepoint per batch
    N: number of series
    preds_path: path to tensor of predictions
    model_path: path to model checkpoint (may be None)
    project_name: name of project (for wandb and logging)
    run_name: name of run (for wandb and logging)
    tags: list of tags describing experiment (may be None)
    just_convert_to_cpu: whether to move predictions from gpu to cpu
    device_name: device to perform computations on
    cpu_checkpt_folder: folder to put cpu checkpoints
    unit_holding_costs: list of costs per unit held
    unit_stockout_costs: list of costs per unit stockout
    unit_var_o_costs: list of costs per unit order variance
    do_scaling: whether to additionally scale predictions (for sktime only)

  Raises:
    ValueError: if `dataset_name` is unsupported; if, for non-sktime runs,
      `preds_path` does not contain `<dataset>_<model>_<objective>`; or if a
      unit-cost list is None when evaluating
    NotImplementedError: if a preds_path with unsupported filetype is provided
  """
  print('preds path: ', preds_path)
  print('model path: ', model_path)
  print('tags: ', tags)

  # `tags` is None when no --tags flag is given (argparse action='append').
  is_sktime = bool(tags) and 'sktime' in tags
  expected_substr = f'{dataset_name}_{model_name}_{optimization_obj}'
  if not is_sktime and expected_substr not in preds_path:
    raise ValueError(
        f'preds_path {preds_path!r} does not contain {expected_substr!r}'
    )

  target_dims = [0]
  if dataset_name == 'm3':
    test_t_min = 36
    test_t_max = 144
    valid_t_start = 72
    test_t_start = 108

    forecasting_horizon = 12
    input_window_size = 24

    lead_time = 6
    scale01 = True
    target_service_level = 0.95
    periodicity = 12
    data_fpath = '../data/m3/m3_industry_monthly_shuffled.csv'
    idx_range = (20, 334)

    if not just_convert_to_cpu:
      tprint('getting dataset factory...')
      dataset_factory = get_m3_data(
          forecasting_horizon=forecasting_horizon,
          minmax_scaling=scale01,
          input_window_size=input_window_size,
          csv_fpath=data_fpath,
          default_nan_value=1e15,
          rolling_evaluation=True,
          idx_range=idx_range,
          N=N,
      )
  elif dataset_name == 'favorita':
    test_t_max = 396
    valid_t_start = 334
    test_t_start = 364
    test_t_min = 180
    forecasting_horizon = 7 if single_rollout else 30
    input_window_size = 90

    lead_time = 7
    scale01 = True
    target_service_level = 0.95
    idx_range = None
    periodicity = 7
    data_fpath = '../data/favorita/favorita_tensor_full.npy'

    if not just_convert_to_cpu:
      tprint('getting dataset factory...')
      dataset_factory = get_favorita_data(
          forecasting_horizon=forecasting_horizon,
          minmax_scaling=scale01,
          input_window_size=input_window_size,
          data_fpath=data_fpath,
          default_nan_value=1e15,
          rolling_evaluation=True,
          N=N,
          test_t_max=test_t_max,
      )
  else:
    # Raised explicitly (not asserted) so the check survives `python -O`.
    raise ValueError(f'Unsupported dataset_name: {dataset_name!r}')

  device = torch.device(device_name)

  # Load predictions and keep only the first `lead_time` entries of axis 2
  # (presumably the forecast-step axis; the slice assumes a 4-D tensor).
  test_preds = _load_test_preds(preds_path, device)
  print('shape of orig test_preds: ', test_preds.shape)
  test_preds = test_preds[:, :, :lead_time, :]
  print('shape of truncated test_preds: ', test_preds.shape)

  # Load the learned alpha of a naive scaling model, if a checkpoint is given.
  learned_alpha = None
  if 'naive' in model_name and model_path is not None:
    checkpoint = torch.load(model_path, map_location=device)
    model = NaiveScalingBaseline(
        forecasting_horizon=forecasting_horizon,
        periodicity=periodicity,
        device=device,
        target_dims=target_dims,
    )
    # NOTE(review): checkpoints whose path contains 'cuda' are assumed to have
    # been saved from a DataParallel-wrapped model, so the model is wrapped
    # before loading the state dict; confirm against the training script.
    if 'cuda' in model_path:
      model = torch.nn.DataParallel(model)
    model.load_state_dict(checkpoint['model_state_dict'])
    learned_alpha = get_learned_alpha(
        model_name=model_name, per_series_models=None, model=model
    )

  if just_convert_to_cpu:
    cpu_checkpt_path = os.path.join(cpu_checkpt_folder, 'cpu_checkpoint.pt')
    cpu_checkpoint = {
        'test_preds': test_preds.cpu(),
        'learned_alpha': learned_alpha,
    }
    torch.save(cpu_checkpoint, cpu_checkpt_path)
    tprint('Saved CPU checkpoint: ' + cpu_checkpt_path)
    return

  # Evaluation-only objects; not needed in convert-to-cpu mode.
  naive_model = NaiveScalingBaseline(
      forecasting_horizon=lead_time,
      init_alpha=1.0,
      periodicity=periodicity,
      frozen=True,
  ).to(device)
  evaluator = Evaluator(
      0, scale01, device, target_dims, no_safety_stock=no_safety_stock
  )

  scale_by_naive_model = False
  quantile_loss = None
  # NOTE(review): get_summary always receives use_wandb=True, even when
  # wandb_log is False; confirm get_summary does not need an active wandb run.
  use_wandb = True

  # argparse action='append' leaves these as None when the flag is omitted;
  # fail with a clear message instead of an opaque itertools TypeError.
  for arg_name, costs in (
      ('unit_holding_costs', unit_holding_costs),
      ('unit_stockout_costs', unit_stockout_costs),
      ('unit_var_o_costs', unit_var_o_costs),
  ):
    if costs is None:
      raise ValueError(f'{arg_name} must be provided (got None)')

  for unit_cost in itertools.product(
      unit_holding_costs, unit_stockout_costs, unit_var_o_costs
  ):
    unit_holding_cost, unit_stockout_cost, unit_var_o_cost = unit_cost
    print(f'============== {unit_cost} ==============')
    config = {
        'model_name': model_name,
        'unit_holding_cost': unit_holding_cost,
        'unit_stockout_cost': unit_stockout_cost,
        'unit_var_o_cost': unit_var_o_cost,
        'dataset_name': dataset_name,
        'optimization_obj': optimization_obj,
        'later_eval': True,
        'N': len(dataset_factory),
    }
    pprint.pprint(config)
    if wandb_log:
      # reinit=True starts a fresh wandb run for every cost combination.
      wandb.init(
          name=run_name,
          project=project_name,
          reinit=True,
          tags=tags,
          config=config,
      )
    start = time.time()
    test_metrics, expanded_test_metrics = get_full_test_metrics(
        dataset_factory,
        test_preds,
        num_workers,
        parallel,
        device,
        test_t_min,
        valid_t_start,
        test_t_start,
        evaluator,
        target_service_level,
        lead_time,
        unit_holding_cost,
        unit_stockout_cost,
        unit_var_o_cost,
        naive_model,
        scale_by_naive_model,
        quantile_loss,
        use_wandb=False,
        sum_ct_metrics=True,
        do_scaling=do_scaling,
    )

    test_results = get_summary(
        test_metrics,
        model_name,
        optimization_obj,
        max_steps,
        start,
        unit_holding_cost,
        unit_stockout_cost,
        unit_var_o_cost,
        valid_t_start,
        learned_alpha,
        quantile_loss,
        naive_model,
        use_wandb,
        expanded_test_metrics,
        idx_range,
    )
    summary = test_results['summary']

    print('runtime: ', time.time() - start)

    if wandb_log:
      wandb.log(
          {
              'combined_test_perfs': wandb.Table(
                  dataframe=pd.DataFrame([summary])
              )
          }
      )
      wandb.finish()
316

317

318
def main():
  """Command-line entry point: parse flags and run `eval_preds`.

  wandb logging is always enabled (`wandb_log=True`) for command-line runs.
  List-valued flags (`--tags` and the `--unit_*_costs` flags) are repeated
  once per value, e.g. `--unit_holding_costs 1 --unit_holding_costs 10`.
  """
  # NOTE(review): presumably 'spawn' is used so that parallel workers can use
  # CUDA safely; confirm before changing.
  multiprocessing.set_start_method('spawn')
  parser = argparse.ArgumentParser()
  parser.add_argument('--parallel', action='store_true')
  # eval_preds reads dataset_name and preds_path on every code path, so fail
  # at parse time rather than with an obscure error deep in evaluation.
  parser.add_argument(
      '--dataset_name', choices=['m3', 'favorita'], required=True
  )
  parser.add_argument('--model_name', type=str)
  parser.add_argument('--optimization_obj', type=str)
  parser.add_argument('--max_steps', type=int)
  parser.add_argument('--N', type=int, default=None)
  parser.add_argument('--num_workers', type=int, default=0)
  parser.add_argument('--preds_path', type=str, required=True)
  parser.add_argument('--model_path', type=str, default=None)
  parser.add_argument('--project_name', type=str)
  parser.add_argument('--run_name', type=str)
  parser.add_argument('--tags', type=str, action='append')
  parser.add_argument('--unit_holding_costs', type=int, action='append')
  parser.add_argument('--unit_stockout_costs', type=int, action='append')
  parser.add_argument('--unit_var_o_costs', type=float, action='append')
  parser.add_argument('--single_rollout', action='store_true')
  parser.add_argument('--no_safety_stock', action='store_true')
  parser.add_argument('--device', type=str, default='cpu')
  parser.add_argument('--do_scaling', action='store_true')
  parser.add_argument('--just_convert_to_cpu', action='store_true')
  parser.add_argument('--cpu_checkpt_folder', type=str, default='./')

  args = parser.parse_args()
  eval_preds(
      wandb_log=True,
      parallel=args.parallel,
      num_workers=args.num_workers,
      dataset_name=args.dataset_name,
      model_name=args.model_name,
      optimization_obj=args.optimization_obj,
      single_rollout=args.single_rollout,
      no_safety_stock=args.no_safety_stock,
      max_steps=args.max_steps,
      N=args.N,
      preds_path=args.preds_path,
      model_path=args.model_path,
      project_name=args.project_name,
      run_name=args.run_name,
      tags=args.tags,
      just_convert_to_cpu=args.just_convert_to_cpu,
      device_name=args.device,
      cpu_checkpt_folder=args.cpu_checkpt_folder,
      unit_holding_costs=args.unit_holding_costs,
      unit_stockout_costs=args.unit_stockout_costs,
      unit_var_o_costs=args.unit_var_o_costs,
      do_scaling=args.do_scaling,
  )
368

369

370
if __name__ == '__main__':
  # Run the command-line entry point only when executed as a script.
  main()
372

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.