# google-research
# 850 lines · 27.2 KB
1# coding=utf-8
2# Copyright 2024 The Google Research Authors.
3#
4# Licensed under the Apache License, Version 2.0 (the "License");
5# you may not use this file except in compliance with the License.
6# You may obtain a copy of the License at
7#
8# http://www.apache.org/licenses/LICENSE-2.0
9#
10# Unless required by applicable law or agreed to in writing, software
11# distributed under the License is distributed on an "AS IS" BASIS,
12# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13# See the License for the specific language governing permissions and
14# limitations under the License.
15
16"""Evaluator object.
17
18Computes all evaluation metrics.
19
20Evaluation metrics include forecasting and inventory performance metrics.
21"""
22
23# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
24# pylint: disable=invalid-name
25
26import numpy as np27from scipy.stats import norm28import torch29from utils.eval_utils import get_ragged_mean30from utils.eval_utils import get_ragged_sum31from utils.eval_utils import get_ragged_var32
33
class Evaluator(object):
  """Evaluator class. Handles differentiable computation of metrics."""

  def __init__(
      self, first_cutoff, scale01, device, target_dims, no_safety_stock
  ):
    """Initializes the evaluator.

    Args:
      first_cutoff: number of leading timepoints excluded from evaluation
        (the warm-up prefix sliced off throughout this class).
      scale01: whether inputs were normalized and must be rescaled back to
        original units before metrics are computed.
      device: torch device on which evaluator-created tensors are placed.
      target_dims: indices of the target dimensions (trailing axis) that
        metrics are computed over.
      no_safety_stock: if True, inventory metrics use zero safety stock.
    """
    self.first_cutoff = first_cutoff
    self.scale01 = scale01
    self.device = device
    self.target_dims = target_dims
    self.no_safety_stock = no_safety_stock
46def _extract_target_dims(self, batch):47"""Extracts the slices corresponding to the target of interest.48
49Args:
50batch: the batch data dictionary
51
52Returns:
53the batch data dictionary, sliced to extract the target
54"""
55new_batch = batch.copy()56for k, v in batch.items():57if k in [58'x',59'x_scale',60'x_offset',61'model_inputs',62'model_targets',63'eval_inputs',64'eval_targets',65]:66if len(v.shape) == 2:67v = v[:, self.target_dims]68elif len(v.shape) == 3:69v = v[:, :, self.target_dims]70elif len(v.shape) == 4:71v = v[:, :, :, self.target_dims]72else:73raise NotImplementedError('Unexpected number of dims: ', v.shape)74new_batch[k] = v75return new_batch76
77def _rescale(self, arr, x_scale, x_offset):78"""Scale the array back up to its original values.79
80Args:
81arr: scaled array
82x_scale: scale
83x_offset: offset
84
85Returns:
86array in its original range of values
87"""
88if not x_scale.shape:89return (arr * x_scale) + x_offset90shape = arr.shape91assert shape[0] == x_scale.shape[0]92
93# repeat scale and offset to match up with arr shape94to_expand = len(shape) - len(x_scale.shape)95for _ in range(to_expand):96x_scale = x_scale.unsqueeze(-1)97x_offset = x_offset.unsqueeze(-1)98x_scale = x_scale.repeat(1, *shape[1:])99x_offset = x_offset.repeat(1, *shape[1:])100return (arr * x_scale) + x_offset101
def _get_lengths_from_time_mask(self, time_mask):
  """Derives ragged sequence lengths from a 5-D time mask.

  Args:
    time_mask: binary mask; assumed layout N x T x L x D x T' where the
      final axis marks which absolute timepoints have passed -- TODO
      confirm against the callers that build 'eval_target_times_mask'.

  Returns:
    forecast_horizon_lengths: int64 tensor (N x T x 1 x D under the
      assumed layout) -- valid steps per forecast horizon, i.e. the sum
      over the horizon (L) axis of the final mask slice.
    time_lengths: int64 tensor (N x 1 x D) -- number of timepoints whose
      first horizon entry is unmasked.
  """
  # Take the last slice of the trailing axis ("D is always last" per the
  # original note), then count valid horizon steps per decode time.
  forecast_horizon_lengths = (
      time_mask[:, :, :, :, -1].sum(dim=2).unsqueeze(2).type(torch.int64)
  )  # D is always last
  # A timepoint contributes to the series length iff its first horizon
  # entry is positive.
  time_lengths = (
      (time_mask[:, :, 0, :, -1] > 0)
      .float()
      .sum(dim=1)
      .unsqueeze(1)
      .type(torch.int64)
  )
  return forecast_horizon_lengths, time_lengths
def compute_mse(
    self,
    preds,
    unfolded_actual_imputed,
    forecast_horizon_lengths,
    time_lengths,
    series_mean=True,
):
  """Compute the mean squared error, taking sequence lengths into account.

  Args:
    preds: predictions tensor (N x T x L x D)
    unfolded_actual_imputed: actual values tensor, unfolded to be the same
      shape as predictions, and imputed to avoid issues with autodiff
    forecast_horizon_lengths: lengths of each forecast horizon
    time_lengths: lengths of each series
    series_mean: whether to take the mean across series

  Returns:
    mean squared error (a scalar if series_mean, else per-series values)
  """
  squared_errs = (preds - unfolded_actual_imputed) ** 2  # N x T x L x D

  # handle first cutoff: the warm-up prefix is excluded from evaluation,
  # so both the errors and the ragged lengths are trimmed consistently.
  squared_errs = squared_errs[:, self.first_cutoff :, :]
  forecast_horizon_lengths = forecast_horizon_lengths[
      :, self.first_cutoff :, :
  ]
  time_lengths = time_lengths - self.first_cutoff

  # get average along forecasting horizon (ragged: only valid steps count)
  mse = get_ragged_mean(
      squared_errs, lens=forecast_horizon_lengths, axis=-2, device=self.device
  )
  # get average along time
  mse = get_ragged_mean(mse, lens=time_lengths, axis=-2, device=self.device)
  # get average across all series
  if series_mean:
    mse = mse.mean()
  return mse
156def _get_std_e(157self,158preds,159unfolded_actual_imputed,160unfolded_time_mask,161eps=1e-5,162):163"""Compute the standard deviation over previous forecast errors.164
165Args:
166preds: predictions tensor (N x T x L x D)
167unfolded_actual_imputed: actual values tensor, unfolded to be the same
168shape as predictions, and imputed to avoid issues with autodiff
169unfolded_time_mask: times mask, the [:,:,:,t] slice corresponds to whether
170the corresponding timepoint has passed
171eps: small constant for stability
172
173Returns:
174tensor of standard deviations
175"""
176N, T, _, _ = preds.shape177
178squared_errs = (preds - unfolded_actual_imputed) ** 2179squared_errs = squared_errs.unsqueeze(-1).repeat(1, 1, 1, 1, T)180masked_errs = squared_errs * unfolded_time_mask181
182# handle first cutoff183masked_errs = masked_errs[:, :, :, :, self.first_cutoff :]184mask = unfolded_time_mask[:, :, :, :, self.first_cutoff :]185mask_denom_nonzero = mask.sum(2).sum(1) # takes errors per timestep186mask_denom_nonzero = (mask_denom_nonzero != 0).float()187mask_denom_nonzero = mask.sum(2).sum(1) + (1881 - mask_denom_nonzero189) # fills in a 1 wherever it's 0190
191avg_per_time = masked_errs.sum(2).sum(1) / mask_denom_nonzero # N x T192avg_per_time = torch.concat(193[torch.zeros((N, 1, 1)).to(self.device), avg_per_time[:, :, :-1]],194axis=2,195) # start @ 0196std_e = torch.sqrt(197avg_per_time + eps198) # square root causes some problems if MSE is 0199std_e = std_e.permute(0, 2, 1)200return std_e201
def compute_forecasting_metrics(
    self,
    preds,
    actual_batch,
    eps=1,
    periodicity=12,
    series_mean=True,
    rolling_eval=False,
    min0=False,
):
  """Computes forecasting metrics.

  Args:
    preds: predicted values (N x T x L x D)
    actual_batch: batch with actual values
    eps: small constant for stability
    periodicity: number of timepoints in a period; used as the lag of the
      seasonal-naive forecast that scales MASE
    series_mean: whether to take mean across series
    rolling_eval: whether evaluation in performed in a roll-forward manner
    min0: whether to cut off predictions at 0 as the minimum (e.g. since
      negative demand is impossible)

  Returns:
    dictionary of forecasting metrics ('mse', 'mpe', 'smape', 'mase')
  """
  N, T, L, D = preds.shape

  x_scale = actual_batch['x_scale']
  x_offset = actual_batch['x_offset']

  if 'eval_targets' in actual_batch:  # dealing with windowed input
    x_imputed = actual_batch['x']  # should have 144 timepoints for m3
    unfolded_actual_imputed = actual_batch['eval_targets']
    forecast_horizon_lengths, time_lengths = self._get_lengths_from_time_mask(
        actual_batch['eval_target_times_mask']
    )
  else:
    x_imputed = actual_batch['x_imputed']
    unfolded_actual_imputed = actual_batch['unfolded_actual_imputed']
    forecast_horizon_lengths = actual_batch['forecast_horizon_lengths']
    time_lengths = actual_batch['time_lengths']

  # Trim ragged lengths to the evaluated suffix (past the warm-up prefix).
  forecast_horizon_lengths = forecast_horizon_lengths[
      :, self.first_cutoff :, :
  ]
  time_lengths = time_lengths - self.first_cutoff

  if self.scale01:
    # Map everything back to original units before computing metrics.
    preds = self._rescale(preds, x_scale, x_offset)
    x_imputed = self._rescale(x_imputed, x_scale, x_offset)
    unfolded_actual_imputed = self._rescale(
        unfolded_actual_imputed, x_scale, x_offset
    )
  if min0:
    preds = torch.nn.functional.relu(preds)
  if x_imputed.min() < 0 or unfolded_actual_imputed.min() < 0:
    # Actuals are expected to be non-negative after rescaling.
    raise NotImplementedError(
        'unexpected value in x_imputed or unfolded_actual_imputed'
    )

  test_actual = unfolded_actual_imputed[:, self.first_cutoff :, :]
  test_preds = preds[:, self.first_cutoff :, :]

  # MSE
  mse = self.compute_mse(
      test_preds,
      test_actual,
      forecast_horizon_lengths,
      time_lengths,
      series_mean=series_mean,
  )

  # MPE (signed; eps guards division when the actual is 0)
  mpe = get_ragged_mean(
      (test_actual - test_preds) / (test_actual + eps),
      lens=forecast_horizon_lengths,
      axis=-2,
      device=self.device,
  )
  mpe = get_ragged_mean(mpe, lens=time_lengths, axis=-2, device=self.device)
  if series_mean:
    mpe = mpe.mean()

  # sMAPE
  smape = get_ragged_mean(
      (test_actual - test_preds).abs() * 2.0 / (test_actual.abs() + eps),
      lens=forecast_horizon_lengths,
      axis=-2,
      device=self.device,
  )
  smape = get_ragged_mean(
      smape, lens=time_lengths, axis=-2, device=self.device
  )

  if series_mean:
    smape = smape.mean()

  # MASE: absolute error scaled by the running in-sample MAE of a
  # seasonal-naive forecast with lag `periodicity`.
  ae = (unfolded_actual_imputed - preds).abs()
  ae = ae[:, self.first_cutoff :, :]

  if 'eval_targets' in actual_batch:
    full_N, full_T, full_D = (
        x_imputed.shape
    )  # expect x_imputed 2nd dim to match original timescale so times correct
    # Running sum of |x_t - x_{t-periodicity}| up to each timepoint...
    scale = torch.zeros((full_N, full_T, full_D)).to(self.device)
    scale[:, periodicity:] = (
        x_imputed[:, periodicity:] - x_imputed[:, :-periodicity]
    ).abs()
    scale = torch.cumsum(scale, dim=1)

    # ...divided by the number of terms in that sum at each timepoint.
    scale_ct = torch.zeros((full_N, full_T, full_D)).to(self.device)
    scale_ct[:, periodicity:] = (
        torch.arange(1, full_T - periodicity + 1)
        .unsqueeze(-1)
        .unsqueeze(0)
        .repeat(full_N, 1, full_D)
    )

    # NOTE: the first `periodicity` entries are 0/0 = NaN here; they are
    # expected to be sliced away or replaced downstream.
    scale = scale / scale_ct
    if rolling_eval:  # each sample is actually a decoding point
      num_start_ts, num_roll_ts, _, num_dims = ae.shape  # t1, t2, l, d

      # figure out scaling factor corresponding to each element of ae
      start_ts = (actual_batch['target_times'][:, 0, 0] - 1).type(
          torch.int64
      )
      # Pad with huge values (so late windows divide errors to ~0 rather
      # than index out of range), then unfold into rolling windows.
      scales_unrolled = torch.cat(
          [
              scale,
              torch.ones(num_start_ts, num_roll_ts - 1, num_dims).to(
                  self.device
              )
              * 1e18,
          ],
          axis=1,
      ).unfold(1, num_roll_ts, 1)
      scales_unrolled = scales_unrolled.permute(0, 1, 3, 2)
      scales_unrolled = torch.cat(
          [
              scales_unrolled,
              torch.ones(
                  scales_unrolled.shape[0],
                  1,
                  scales_unrolled.shape[2],
                  scales_unrolled.shape[3],
              ).to(self.device),
          ],
          axis=1,
      )

      # Select, per series, the window starting at its decode time.
      start_ts = torch.clamp(start_ts, max=scales_unrolled.shape[1] - 1)
      scale = torch.gather(
          scales_unrolled,
          1,
          start_ts.unsqueeze(-2)
          .unsqueeze(-2)
          .repeat(1, 1, scales_unrolled.shape[2], 1),
      ).squeeze(1)
    else:
      first_cutoff = int(actual_batch['target_times'].min().item()) - 1
      if 'max_t_cutoff' in actual_batch:
        max_t_cutoff = actual_batch['max_t_cutoff']
        scale = scale[:, first_cutoff:max_t_cutoff]
      else:
        scale = scale[:, first_cutoff:]
  else:
    # Non-windowed path: seasonal-naive scale built from the first target
    # dimension only.
    assert periodicity <= self.first_cutoff
    scale = torch.zeros((N, T)).to(self.device)
    scale[:, periodicity:] = (
        unfolded_actual_imputed[:, periodicity:, 0]
        - unfolded_actual_imputed[:, :-periodicity, 0]
    ).abs()
    scale = torch.cumsum(scale, dim=1)

    scale_ct = torch.zeros((N, T)).to(self.device)
    scale_ct[:, periodicity:] = torch.arange(1, T - periodicity + 1)

    scale = scale / scale_ct
    scale = scale[:, self.first_cutoff :]
    # NOTE(review): this branch leaves `scale` 2-D, which looks
    # incompatible with the 3-D NaN padding below -- confirm whether this
    # path is still exercised.

  # Pad with L-1 trailing NaNs and unfold so each decode time carries its
  # scale across all L horizon steps; NaN (incomplete) windows fall back
  # to a scale of 1.
  nans = np.empty((N, L - 1, D))
  nans[:] = np.nan
  nans = torch.from_numpy(nans).float().to(self.device)
  scale = torch.cat([scale, nans], dim=1).unfold(1, L, 1).permute(0, 1, 3, 2)
  scale = torch.nan_to_num(scale, nan=1.0)

  if scale.shape[1] < ae.shape[1]:
    # NOTE(review): debug print left in -- consider removing or logging.
    print(scale.shape, ae.shape)
    ones = torch.ones(
        ae.shape[0], ae.shape[1] - scale.shape[1], ae.shape[2], ae.shape[3]
    )
    scale = torch.cat([scale, ones], dim=1)
  ase = ae / scale

  mase = get_ragged_mean(
      ase, lens=forecast_horizon_lengths, axis=-2, device=self.device
  )
  mase = get_ragged_mean(mase, lens=time_lengths, axis=-2, device=self.device)
  if series_mean:
    mase = mase.mean()

  forecasting_metrics = {
      'mse': mse,
      'mpe': mpe,
      'smape': smape,
      'mase': mase,
  }
  return forecasting_metrics
410def _get_lagged(self, matrix, lag=1, same_size=True):411N, _, D = matrix.shape # N x T x D412pad = torch.zeros((N, lag, D)).to(self.device)413lagged = torch.concat([pad, matrix], axis=1)414if same_size:415lagged = lagged[:, :-lag, :]416return lagged417
def compute_inventory_metrics(
    self,
    preds,
    actual_batch,
    target_service_level=0.95,
    unit_holding_cost=1,
    unit_var_o_cost=1.0 / 100000.0,
    unit_stockout_cost=1,
    series_mean=True,
    quantile_loss=None,
    naive_metrics=None,
    min0=False,
):
  """Computes inventory metrics.

  Derives order decisions as (lead-time forecast + safety stock) minus the
  current inventory position, then accumulates holding, stockout, and
  order-variance costs from the implied net inventory levels.

  Args:
    preds: predicted values (N x T x L x D)
    actual_batch: batch with actual values
    target_service_level: service level to use for safety stock calculation
    unit_holding_cost: cost per unit held
    unit_var_o_cost: cost per unit order variance
    unit_stockout_cost: cost per unit stockout
    series_mean: whether to take mean across series
    quantile_loss: quantile loss objective, if relevant (when set, safety
      stock is zero, as with no_safety_stock)
    naive_metrics: baseline metrics; when given, scaled/relative variants
      are added to the output
    min0: whether to cut off predictions at 0 as the minimum (e.g. since
      negative demand is impossible)

  Returns:
    dictionary of inventory metrics (costs, service levels, order stats,
    relative-to-baseline aggregates when naive_metrics is given) plus the
    raw per-timestep tensors under 'inventory_values'.
  """

  x_scale = actual_batch['x_scale']
  x_offset = actual_batch['x_offset']
  # NOTE(review): 'eval_targets' is read unconditionally here, so the
  # else-branch below (non-windowed batches) would raise KeyError before
  # it is reached -- confirm whether that path is still needed.
  _, _, lead_time, target_D = actual_batch['eval_targets'].shape  # N, T, L, D

  if 'eval_targets' in actual_batch:  # dealing with windowed input
    # Demand series = first horizon step of the unfolded targets.
    x_imputed = actual_batch['eval_targets'][:, :, 0]
    unfolded_actual_imputed = actual_batch['eval_targets']
    forecast_horizon_lengths, time_lengths = self._get_lengths_from_time_mask(
        actual_batch['eval_target_times_mask']
    )
    unfolded_time_mask = actual_batch['eval_target_times_mask']
  else:
    x_imputed = actual_batch['x_imputed']
    unfolded_actual_imputed = actual_batch['unfolded_actual_imputed']
    time_lengths = actual_batch['time_lengths']
    forecast_horizon_lengths = actual_batch['forecast_horizon_lengths']
    unfolded_time_mask = actual_batch['unfolded_time_mask']

  time_lengths = time_lengths - self.first_cutoff
  if self.scale01:
    # Work in original units.
    preds = self._rescale(preds, x_scale, x_offset)
    unfolded_actual_imputed = self._rescale(
        unfolded_actual_imputed, x_scale, x_offset
    )
    x_imputed = self._rescale(x_imputed, x_scale, x_offset)
  if min0:
    preds = torch.nn.functional.relu(preds)

  # Mask out horizon steps that never materialize, then sum across the
  # horizon to get the lead-time demand forecast per decode time.
  preds = preds * unfolded_time_mask[:, :, :, :, -1]
  lead_forecasts = preds.sum(axis=2)  # N x T
  lead_forecasts = lead_forecasts[:, self.first_cutoff :]
  if quantile_loss or self.no_safety_stock:
    safety_stocks = torch.zeros(lead_forecasts.shape).to(self.device)
  else:
    # Normal-approximation safety stock from the std of past errors.
    std_e = self._get_std_e(
        preds,
        unfolded_actual_imputed,
        unfolded_time_mask,
        eps=1e-5,
    )  # N x T
    std_e = std_e * lead_time  # approximate lead time std_e
    safety_const = norm.ppf(target_service_level)
    safety_stocks = safety_const * std_e  # N x T

  # Previous step's order-up-to level minus the demand just realized.
  inventory_positions = (
      self._get_lagged(lead_forecasts)
      + self._get_lagged(safety_stocks)
      - x_imputed[:, self.first_cutoff :]
  )

  # Order up to (lead-time forecast + safety stock).
  orders = lead_forecasts + safety_stocks - inventory_positions

  # Total demand over the trailing `lead_time` steps at each timepoint.
  recent_demand = (
      self._get_lagged(x_imputed, lag=lead_time - 1, same_size=False)
      .unfold(1, lead_time, 1)
      .permute(0, 1, 3, 2)
  )
  # works because there's at least lead time worth of real obs
  # NOTE(review): the slice below is empty when lead_time == 1; this path
  # appears to assume lead_time >= 2 -- confirm.
  recent_horizon_lengths = torch.cat(
      [
          torch.ones(x_imputed.shape[0], lead_time - 1, 1, target_D).to(
              self.device
          )
          * lead_time,
          forecast_horizon_lengths[:, : -(lead_time - 1), :, :],
      ],
      axis=1,
  ).type(torch.int64)
  recent_demand = get_ragged_sum(
      recent_demand, recent_horizon_lengths, device=self.device, axis=2
  )

  # On-hand inventory: the level targeted lead_time steps ago minus the
  # demand realized since then.
  net_inventory_levels = (
      self._get_lagged(lead_forecasts, lag=lead_time)
      + self._get_lagged(safety_stocks, lag=lead_time)
      - recent_demand
  )

  # Stock ordered but not yet delivered.
  work_in_progress = inventory_positions - net_inventory_levels

  holding_cost = (
      torch.nn.functional.relu(net_inventory_levels) * unit_holding_cost
  )
  holding_cost = get_ragged_mean(
      holding_cost, time_lengths, device=self.device, axis=1
  )
  if series_mean:
    holding_cost = holding_cost.mean()

  # Softplus variant: smooth surrogate of the relu holding cost.
  soft_holding_cost = (
      torch.nn.functional.softplus(net_inventory_levels) * unit_holding_cost
  )
  soft_holding_cost = get_ragged_mean(
      soft_holding_cost, time_lengths, device=self.device, axis=1
  )
  if series_mean:
    soft_holding_cost = soft_holding_cost.mean()

  var_o = get_ragged_var(
      orders,
      torch.maximum(time_lengths, torch.Tensor([0]).to(self.device)).type(
          torch.int64
      ),
      device=self.device,
      axis=1,
  )  # avg of variance of orders for each series
  if series_mean:
    var_o = var_o.mean()

  var_o_cost = var_o * unit_var_o_cost

  # proportion of orders that are negative
  prop_neg_orders = get_ragged_mean(
      (orders < 0).float(), time_lengths, device=self.device, axis=1
  )
  if series_mean:
    prop_neg_orders = prop_neg_orders.mean()

  # how often stockout occurs
  achieved_service_level = get_ragged_mean(
      (net_inventory_levels >= 0).float(),
      time_lengths,
      device=self.device,
      axis=1,
  )
  if series_mean:
    achieved_service_level = achieved_service_level.mean()

  # Sharp-sigmoid surrogate of the (non-differentiable) service level.
  soft_alpha = torch.sigmoid(net_inventory_levels * 1e2)
  soft_alpha = get_ragged_mean(
      soft_alpha, time_lengths, device=self.device, axis=1
  )
  if series_mean:
    soft_alpha = soft_alpha.mean()

  # stockout cost
  stockout_cost = (
      torch.nn.functional.relu(-net_inventory_levels) * unit_stockout_cost
  )
  stockout_cost = get_ragged_mean(
      stockout_cost, time_lengths, device=self.device, axis=1
  )
  if series_mean:
    stockout_cost = stockout_cost.mean()

  # compute rms of the competing objectives (service level inverted so
  # that lower is better for every term)
  rms = torch.sqrt(
      (
          holding_cost**2
          + var_o**2
          + (1.0 / (achieved_service_level + 1e-5)) ** 2
      )
      / 3.0
  )

  # cost
  total_cost = holding_cost + stockout_cost + var_o_cost

  # Raw per-timestep tensors, exposed for inspection/plotting.
  inventory_values = {
      'inventory_positions': inventory_positions,
      'net_inventory_levels': net_inventory_levels,
      'work_in_progress': work_in_progress,
      'safety_stocks': safety_stocks,
      'orders': orders,
      'lead_forecasts': lead_forecasts,
      'unfolded_actual_imputed': unfolded_actual_imputed,
      'unfolded_time_mask': unfolded_time_mask,
      'time_lengths': time_lengths,
      'demand': x_imputed,
  }

  inventory_metrics = {
      'holding_cost': holding_cost,
      'soft_holding_cost': soft_holding_cost,
      'var_o': var_o,
      'var_o_cost': var_o_cost,
      'prop_neg_orders': prop_neg_orders,
      'achieved_service_level': achieved_service_level,
      'soft_achieved_service_level': soft_alpha,
      'stockout_cost': stockout_cost,
      'rms': rms,
      'total_cost': total_cost,
      'inventory_values': inventory_values,
  }

  # compute scaled_rms and relative-to-baseline aggregates
  if naive_metrics:
    scaled_holding_cost = holding_cost / (naive_metrics['holding_cost'] + 1)
    scaled_var_o = var_o / (naive_metrics['var_o'] + 1)
    scaled_achieved_service_level = achieved_service_level / (
        naive_metrics['achieved_service_level'] + 0.1
    )
    scaled_rms = torch.sqrt(
        (
            scaled_holding_cost**2
            + scaled_var_o**2
            + (1.0 / (scaled_achieved_service_level + 0.1)) ** 2
        )
        / 3.0
    )

    # Relative changes vs. the baseline, each squashed through sigmoid
    # before aggregation.
    rel_holding_cost = (holding_cost - naive_metrics['holding_cost']) / (
        naive_metrics['holding_cost'] + 1
    )
    rel_var_o = (var_o - naive_metrics['var_o']) / (
        naive_metrics['var_o'] + 1
    )
    rel_achieved_service_level = (
        (1.0 / (achieved_service_level + 0.1))
        - (1.0 / (naive_metrics['achieved_service_level'] + 0.1))
    ) / (1.0 / (naive_metrics['achieved_service_level'] + 0.1))
    rel_stockout_cost = (stockout_cost - naive_metrics['stockout_cost']) / (
        naive_metrics['stockout_cost'] + 1
    )
    rel_rms_avg = (
        torch.sigmoid(rel_holding_cost)
        + torch.sigmoid(rel_var_o)
        + torch.sigmoid(rel_achieved_service_level)
    ) / 3.0
    rel_rms_2 = (
        (torch.sigmoid(rel_holding_cost) ** 2)
        + (torch.sigmoid(rel_var_o) ** 2)
        + (torch.sigmoid(rel_achieved_service_level) ** 2)
    )
    rel_rms_3 = (
        (torch.sigmoid(rel_holding_cost) ** 3)
        + (torch.sigmoid(rel_var_o) ** 3)
        + (torch.sigmoid(rel_achieved_service_level) ** 3)
    )
    rel_rms_5 = (
        (torch.sigmoid(rel_holding_cost) ** 5)
        + (torch.sigmoid(rel_var_o) ** 5)
        + (torch.sigmoid(rel_achieved_service_level) ** 5)
    )
    rel_rms_logsumexp = torch.logsumexp(
        torch.cat(
            [
                torch.sigmoid(rel_holding_cost).unsqueeze(0),
                torch.sigmoid(rel_var_o).unsqueeze(0),
                torch.sigmoid(rel_achieved_service_level).unsqueeze(0),
            ],
            dim=0,
        ),
        dim=0,
    )

    # Variants that use stockout cost in place of the service level.
    rel_rms_stockout_2 = (
        (torch.sigmoid(rel_holding_cost) ** 2)
        + (torch.sigmoid(rel_var_o) ** 2)
        + (torch.sigmoid(rel_stockout_cost) ** 2)
    )
    rel_rms_stockout_3 = (
        (torch.sigmoid(rel_holding_cost) ** 3)
        + (torch.sigmoid(rel_var_o) ** 3)
        + (torch.sigmoid(rel_stockout_cost) ** 3)
    )
    rel_rms_stockout_5 = (
        (torch.sigmoid(rel_holding_cost) ** 5)
        + (torch.sigmoid(rel_var_o) ** 5)
        + (torch.sigmoid(rel_stockout_cost) ** 5)
    )

    inventory_metrics['scaled_rms'] = scaled_rms
    inventory_metrics['rel_rms_avg'] = rel_rms_avg
    inventory_metrics['rel_rms_2'] = rel_rms_2
    inventory_metrics['rel_rms_3'] = rel_rms_3
    inventory_metrics['rel_rms_5'] = rel_rms_5
    inventory_metrics['rel_rms_logsumexp'] = rel_rms_logsumexp
    inventory_metrics['rel_rms_stockout_2'] = rel_rms_stockout_2
    inventory_metrics['rel_rms_stockout_3'] = rel_rms_stockout_3
    inventory_metrics['rel_rms_stockout_5'] = rel_rms_stockout_5

  return inventory_metrics
def compute_all_metrics(
    self,
    preds,
    actual_batch,
    target_service_level,
    unit_holding_cost,
    unit_stockout_cost,
    unit_var_o_cost,
    series_mean,
    quantile_loss,
    naive_model,
    scale_by_naive_model=False,
    rolling_eval=False,
    min0=False,
):
  """Given predictions, computes all metrics of interest.

  Args:
    preds: predicted values (N x T x L x D)
    actual_batch: batch with actual values
    target_service_level: service level to use for safety stock calculation
    unit_holding_cost: cost per unit held
    unit_stockout_cost: cost per unit stockout
    unit_var_o_cost: cost per unit order variance
    series_mean: whether to take mean across series
    quantile_loss: quantile loss objective, if relevant
    naive_model: baseline model
    scale_by_naive_model: whether to scale performance by baseline model
    rolling_eval: whether evaluation is roll-forward
    min0: whether to cut off predictions at 0 as the minimum (e.g. since
      negative demand is impossible)

  Returns:
    dictionary with all inventory and forecasting metrics, plus
    'quantile_loss' when a quantile objective is provided
  """
  actual_batch = self._extract_target_dims(actual_batch)

  all_metrics = {}

  _, T, _, _ = preds.shape  # N x T x L x D

  # Keep per-series values un-averaged until the end so the naive-scaled
  # variants can be formed elementwise first.
  immediate_series_mean = False

  # compute naive model metrics (no gradients needed for the baseline)
  naive_all_metrics = {}
  with torch.no_grad():
    naive_preds = naive_model(actual_batch, in_eval=True)
    naive_preds = naive_preds[:, :T, :, :]
    naive_inventory_metrics = self.compute_inventory_metrics(
        naive_preds,
        actual_batch,
        target_service_level=target_service_level,
        unit_holding_cost=unit_holding_cost,
        unit_stockout_cost=unit_stockout_cost,
        unit_var_o_cost=unit_var_o_cost,
        series_mean=immediate_series_mean,
    )
    naive_forecasting_metrics = self.compute_forecasting_metrics(
        naive_preds,
        actual_batch,
        series_mean=immediate_series_mean,
        rolling_eval=rolling_eval,
    )
    naive_all_metrics.update(naive_inventory_metrics)
    naive_all_metrics.update(naive_forecasting_metrics)

  # compute inventory metrics
  inventory_metrics = self.compute_inventory_metrics(
      preds,
      actual_batch,
      target_service_level=target_service_level,
      unit_holding_cost=unit_holding_cost,
      unit_stockout_cost=unit_stockout_cost,
      unit_var_o_cost=unit_var_o_cost,
      series_mean=immediate_series_mean,
      quantile_loss=quantile_loss,
      naive_metrics=naive_all_metrics,
  )

  for metric_name, metric_val in inventory_metrics.items():
    if metric_name == 'inventory_values':
      # Raw tensors for downstream inspection; never scaled or averaged.
      all_metrics[metric_name] = metric_val
      continue
    if scale_by_naive_model:
      metric_val = metric_val / (naive_all_metrics[metric_name] + 1e-5)
    if (not immediate_series_mean) and series_mean:
      metric_val = metric_val.mean()
    all_metrics[metric_name] = metric_val

  # compute forecasting metrics
  forecasting_metrics = self.compute_forecasting_metrics(
      preds,
      actual_batch,
      series_mean=immediate_series_mean,
      rolling_eval=rolling_eval,
  )

  for metric_name, metric_val in forecasting_metrics.items():
    # MPE is excluded from naive scaling (presumably because it is signed).
    if scale_by_naive_model and metric_name != 'mpe':
      metric_val = metric_val / (naive_all_metrics[metric_name] + 1e-5)
    if (not immediate_series_mean) and series_mean:
      metric_val = metric_val.mean()
    all_metrics[metric_name] = metric_val

  # add quantile loss
  if quantile_loss:
    # BUG FIX: `targets` was previously bound only inside the scale01
    # branch, so quantile evaluation with scale01=False raised NameError.
    targets = actual_batch['eval_targets']
    if self.scale01:
      preds = self._rescale(
          preds, actual_batch['x_scale'], actual_batch['x_offset']
      )
      targets = self._rescale(
          targets,
          actual_batch['x_scale'],
          actual_batch['x_offset'],
      )
    if min0:
      preds = torch.nn.functional.relu(preds)
    forecast_horizon_lengths, time_lengths = self._get_lengths_from_time_mask(
        actual_batch['eval_target_times_mask']
    )
    qloss = quantile_loss(
        preds, targets, forecast_horizon_lengths, time_lengths
    )
    if series_mean:
      qloss = qloss.mean()
    all_metrics['quantile_loss'] = qloss

  return all_metrics