google-research
371 lines · 11.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Given a set of predictions, computes evaluation metrics.
17
18After a model has been trained using main.py and its predictions have been saved,
19this script can be used to further evaluate the predictions under various
20configurations (e.g. cost tradeoffs).
21"""
22
23# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
24# pylint: disable=invalid-name
25
26import argparse27import itertools28import multiprocessing29import os30import pickle31import pprint32import time33from data_formatting.datasets import get_favorita_data34from data_formatting.datasets import get_m3_data35from lib.evaluator import Evaluator36from lib.naive_scaling_baseline import NaiveScalingBaseline37from main import get_full_test_metrics38from main import get_learned_alpha39import numpy as np40import pandas as pd41import torch42from utils.log_utils import get_summary43from utils.log_utils import tprint44import wandb45
46
def _load_test_preds(preds_path, device):
  """Loads saved test predictions and moves them to `device`.

  Supported formats are `.pkl` (a tensor, or a list of tensors concatenated
  along dim 1), `.npy` (3-D arrays get a trailing singleton dimension) and
  files ending in `test_preds.pt`.

  Args:
    preds_path: path to the saved predictions
    device: torch.device to place the predictions on

  Returns:
    The predictions as a torch tensor on `device`.

  Raises:
    NotImplementedError: if `preds_path` has an unsupported file type
  """
  if preds_path.endswith('.pkl'):
    # NOTE: pickle can execute arbitrary code on load; only evaluate
    # prediction files from trusted sources.
    with open(preds_path, 'rb') as fin:
      test_preds = pickle.load(fin)
    if isinstance(test_preds, list):
      test_preds = torch.cat(test_preds, dim=1)
  elif preds_path.endswith('.npy'):
    test_preds = torch.from_numpy(np.load(preds_path))
    if len(test_preds.shape) == 3:
      # Add a trailing dim so every format ends up with the same rank.
      test_preds = test_preds.unsqueeze(-1)
  elif preds_path.endswith('test_preds.pt'):
    # map_location lets GPU-saved tensors load on a CPU-only machine.
    test_preds = torch.load(preds_path, map_location=device)
  else:
    raise NotImplementedError('Unrecognized file type: ' + preds_path)
  return test_preds.to(device)


def _load_learned_alpha(
    model_name, model_path, forecasting_horizon, periodicity, device,
    target_dims,
):
  """Returns the learned alpha of a naive-model checkpoint, or None.

  Only models whose name contains 'naive' and that have a `model_path` are
  loaded; otherwise None is returned.
  """
  if 'naive' not in model_name or model_path is None:
    return None
  # map_location allows loading a GPU-saved checkpoint on a CPU-only host.
  checkpoint = torch.load(model_path, map_location=device)
  model = NaiveScalingBaseline(
      forecasting_horizon=forecasting_horizon,
      periodicity=periodicity,
      device=device,
      target_dims=target_dims,
  )
  if 'cuda' in model_path:
    # Presumably such checkpoints were saved from a DataParallel-wrapped model
    # (state-dict keys prefixed with 'module.'); wrap so the keys match.
    model = torch.nn.DataParallel(model)
  model.load_state_dict(checkpoint['model_state_dict'])
  return get_learned_alpha(
      model_name=model_name, per_series_models=None, model=model
  )


def eval_preds(
    wandb_log,
    parallel,
    num_workers,
    dataset_name,
    model_name,
    optimization_obj,
    single_rollout,
    no_safety_stock,
    max_steps,
    N,
    preds_path,
    model_path,
    project_name,
    run_name,
    tags,
    just_convert_to_cpu,
    device_name,
    cpu_checkpt_folder,
    unit_holding_costs,
    unit_stockout_costs,
    unit_var_o_costs,
    do_scaling,
):
  """Evaluate model predictions.

  Loads test predictions from `preds_path` and, for naive models, the learned
  alpha from `model_path`. It then computes test metrics once for every
  combination in `unit_holding_costs` x `unit_stockout_costs` x
  `unit_var_o_costs`. Each combination is optionally logged as its own wandb
  run.

  If `just_convert_to_cpu` is set, no metrics are computed. Instead, the
  predictions (moved to CPU) and the learned alpha are saved to
  `<cpu_checkpt_folder>/cpu_checkpoint.pt`.

  Args:
    wandb_log: whether to log the metrics to wandb
    parallel: whether to evaluate in parallel
    num_workers: number of workers for parallel dataloading
    dataset_name: name of dataset ('m3' or 'favorita')
    model_name: name of model
    optimization_obj: name of optimization obj
    single_rollout: whether single rollout (vs. double rollout)
    no_safety_stock: whether to include safety stock in order-up-to policy
    max_steps: num steps per timepoint per batch
    N: number of series
    preds_path: path to tensor of predictions
    model_path: path to model checkpoint
    project_name: name of project (for wandb and logging)
    run_name: name of run (for wandb and logging)
    tags: list of tags describing experiment, or None for no tags
    just_convert_to_cpu: whether to move predictions from gpu to cpu
    device_name: device to perform computations on
    cpu_checkpt_folder: folder to put cpu checkpoints
    unit_holding_costs: list of costs per unit held
    unit_stockout_costs: list of costs per unit stockout
    unit_var_o_costs: list of costs per unit order variance
    do_scaling: whether to additionally scale predictions (for sktime only)

  Raises:
    NotImplementedError: if a preds_path with unsupported filetype is provided
    ValueError: if metrics are to be computed but one of the unit cost lists
      is None
  """
  print('preds path: ', preds_path)
  print('model path: ', model_path)
  print('tags: ', tags)
  # argparse's action='append' leaves `tags` as None when no --tags flag is
  # given; treat that as "no tags" rather than crashing with a TypeError.
  if tags is None or 'sktime' not in tags:
    # Sanity check that the predictions match this dataset/model/objective.
    assert f'{dataset_name}_{model_name}_{optimization_obj}' in preds_path

  target_dims = [0]
  if dataset_name == 'm3':
    test_t_min = 36
    test_t_max = 144
    valid_t_start = 72
    test_t_start = 108

    forecasting_horizon = 12
    input_window_size = 24

    lead_time = 6
    scale01 = True
    target_service_level = 0.95
    periodicity = 12
    data_fpath = '../data/m3/m3_industry_monthly_shuffled.csv'
    idx_range = (20, 334)

    if not just_convert_to_cpu:
      tprint('getting dataset factory...')
      dataset_factory = get_m3_data(
          forecasting_horizon=forecasting_horizon,
          minmax_scaling=scale01,
          input_window_size=input_window_size,
          csv_fpath=data_fpath,
          default_nan_value=1e15,
          rolling_evaluation=True,
          idx_range=idx_range,
          N=N,
      )
  else:
    assert dataset_name == 'favorita'
    test_t_max = 396
    valid_t_start = 334
    test_t_start = 364
    test_t_min = 180
    forecasting_horizon = 30
    if single_rollout:
      forecasting_horizon = 7
    input_window_size = 90

    lead_time = 7
    scale01 = True
    target_service_level = 0.95
    idx_range = None
    periodicity = 7
    data_fpath = '../data/favorita/favorita_tensor_full.npy'

    if not just_convert_to_cpu:
      tprint('getting dataset factory...')
      dataset_factory = get_favorita_data(
          forecasting_horizon=forecasting_horizon,
          minmax_scaling=scale01,
          input_window_size=input_window_size,
          data_fpath=data_fpath,
          default_nan_value=1e15,
          rolling_evaluation=True,
          N=N,
          test_t_max=test_t_max,
      )

  device = torch.device(device_name)
  naive_model = NaiveScalingBaseline(
      forecasting_horizon=lead_time,
      init_alpha=1.0,
      periodicity=periodicity,
      frozen=True,
  ).to(device)
  evaluator = Evaluator(
      0, scale01, device, target_dims, no_safety_stock=no_safety_stock
  )

  scale_by_naive_model = False
  quantile_loss = None
  # NOTE(review): always True even when `wandb_log` is False; confirm that
  # get_summary does not require an active wandb run.
  use_wandb = True

  test_preds = _load_test_preds(preds_path, device)
  print('shape of orig test_preds: ', test_preds.shape)
  # Keep only the first `lead_time` entries along the third axis.
  test_preds = test_preds[:, :, :lead_time, :]
  print('shape of truncated test_preds: ', test_preds.shape)

  learned_alpha = _load_learned_alpha(
      model_name,
      model_path,
      forecasting_horizon,
      periodicity,
      device,
      target_dims,
  )

  if just_convert_to_cpu:
    cpu_checkpt_path = os.path.join(cpu_checkpt_folder, 'cpu_checkpoint.pt')
    cpu_checkpoint = {
        'test_preds': test_preds.cpu(),
        'learned_alpha': learned_alpha,
    }
    torch.save(cpu_checkpoint, cpu_checkpt_path)
    tprint('Saved CPU checkpoint: ' + cpu_checkpt_path)
    return

  # Fail fast with a clear message instead of a TypeError from
  # itertools.product when a cost flag was never supplied.
  for cost_arg_name, cost_values in (
      ('unit_holding_costs', unit_holding_costs),
      ('unit_stockout_costs', unit_stockout_costs),
      ('unit_var_o_costs', unit_var_o_costs),
  ):
    if cost_values is None:
      raise ValueError(f'{cost_arg_name} must be provided (got None).')

  unit_costs = list(
      itertools.product(
          unit_holding_costs,
          unit_stockout_costs,
          unit_var_o_costs,
      )
  )
  for unit_cost in unit_costs:
    unit_holding_cost, unit_stockout_cost, unit_var_o_cost = unit_cost
    print(f'============== {unit_cost} ==============')
    config = {
        'model_name': model_name,
        'unit_holding_cost': unit_holding_cost,
        'unit_stockout_cost': unit_stockout_cost,
        'unit_var_o_cost': unit_var_o_cost,
        'dataset_name': dataset_name,
        'optimization_obj': optimization_obj,
        'later_eval': True,
        'N': len(dataset_factory),
    }
    pprint.pprint(config)
    if wandb_log:
      # One wandb run per cost configuration.
      wandb.init(
          name=run_name,
          project=project_name,
          reinit=True,
          tags=tags,
          config=config,
      )
    start = time.time()
    test_metrics, expanded_test_metrics = get_full_test_metrics(
        dataset_factory,
        test_preds,
        num_workers,
        parallel,
        device,
        test_t_min,
        valid_t_start,
        test_t_start,
        evaluator,
        target_service_level,
        lead_time,
        unit_holding_cost,
        unit_stockout_cost,
        unit_var_o_cost,
        naive_model,
        scale_by_naive_model,
        quantile_loss,
        use_wandb=False,
        sum_ct_metrics=True,
        do_scaling=do_scaling,
    )

    test_results = get_summary(
        test_metrics,
        model_name,
        optimization_obj,
        max_steps,
        start,
        unit_holding_cost,
        unit_stockout_cost,
        unit_var_o_cost,
        valid_t_start,
        learned_alpha,
        quantile_loss,
        naive_model,
        use_wandb,
        expanded_test_metrics,
        idx_range,
    )
    summary = test_results['summary']

    print('runtime: ', time.time() - start)

    if wandb_log:
      wandb.log(
          {
              'combined_test_perfs': wandb.Table(
                  dataframe=pd.DataFrame([summary])
              )
          }
      )
      wandb.finish()
317
def main():
  """Parses command-line arguments and runs `eval_preds` with them.

  wandb logging is always enabled when invoked from the command line.
  """
  # 'spawn' start method — presumably so that worker processes can use CUDA
  # safely; confirm before changing.
  multiprocessing.set_start_method('spawn')
  parser = argparse.ArgumentParser(
      description='Compute evaluation metrics for saved model predictions.'
  )
  parser.add_argument(
      '--parallel', action='store_true', help='Evaluate in parallel.'
  )
  # dataset_name, model_name and preds_path are mandatory: eval_preds cannot
  # run without them, so let argparse report their absence clearly.
  parser.add_argument(
      '--dataset_name',
      choices=['m3', 'favorita'],
      required=True,
      help='Name of the dataset.',
  )
  parser.add_argument(
      '--model_name', type=str, required=True, help='Name of the model.'
  )
  parser.add_argument(
      '--optimization_obj', type=str, help='Name of the optimization objective.'
  )
  parser.add_argument(
      '--max_steps', type=int, help='Number of steps per timepoint per batch.'
  )
  parser.add_argument(
      '--N', type=int, default=None, help='Number of series.'
  )
  parser.add_argument(
      '--num_workers',
      type=int,
      default=0,
      help='Number of workers for parallel dataloading.',
  )
  parser.add_argument(
      '--preds_path',
      type=str,
      required=True,
      help='Path to the saved predictions (.pkl, .npy or *test_preds.pt).',
  )
  parser.add_argument(
      '--model_path', type=str, default=None, help='Path to a model checkpoint.'
  )
  parser.add_argument(
      '--project_name', type=str, help='Project name (for wandb and logging).'
  )
  parser.add_argument(
      '--run_name', type=str, help='Run name (for wandb and logging).'
  )
  parser.add_argument(
      '--tags',
      type=str,
      action='append',
      help='Tag describing the experiment; may be repeated.',
  )
  parser.add_argument(
      '--unit_holding_costs',
      type=int,
      action='append',
      help='Cost per unit held; repeat to evaluate several values.',
  )
  parser.add_argument(
      '--unit_stockout_costs',
      type=int,
      action='append',
      help='Cost per unit of stockout; repeat to evaluate several values.',
  )
  parser.add_argument(
      '--unit_var_o_costs',
      type=float,
      action='append',
      help='Cost per unit of order variance; repeat to evaluate several values.',
  )
  parser.add_argument(
      '--single_rollout',
      action='store_true',
      help='Use a single rollout (vs. double rollout).',
  )
  parser.add_argument(
      '--no_safety_stock',
      action='store_true',
      help='Passed to the evaluator as no_safety_stock.',
  )
  parser.add_argument(
      '--device',
      type=str,
      default='cpu',
      help='Torch device to perform computations on.',
  )
  parser.add_argument(
      '--do_scaling',
      action='store_true',
      help='Additionally scale predictions (for sktime only).',
  )
  parser.add_argument(
      '--just_convert_to_cpu',
      action='store_true',
      help='Only save the predictions to a CPU checkpoint and exit.',
  )
  parser.add_argument(
      '--cpu_checkpt_folder',
      type=str,
      default='./',
      help='Folder in which to write the CPU checkpoint.',
  )

  args = parser.parse_args()
  eval_preds(
      wandb_log=True,
      parallel=args.parallel,
      num_workers=args.num_workers,
      dataset_name=args.dataset_name,
      model_name=args.model_name,
      optimization_obj=args.optimization_obj,
      single_rollout=args.single_rollout,
      no_safety_stock=args.no_safety_stock,
      max_steps=args.max_steps,
      N=args.N,
      preds_path=args.preds_path,
      model_path=args.model_path,
      project_name=args.project_name,
      run_name=args.run_name,
      tags=args.tags,
      just_convert_to_cpu=args.just_convert_to_cpu,
      device_name=args.device,
      cpu_checkpt_folder=args.cpu_checkpt_folder,
      unit_holding_costs=args.unit_holding_costs,
      unit_stockout_costs=args.unit_stockout_costs,
      unit_var_o_costs=args.unit_var_o_costs,
      do_scaling=args.do_scaling,
  )
369
if __name__ == '__main__':
  # Script entry point; importing this module does not trigger evaluation.
  main()