google-research

Форк
0
850 строк · 27.2 Кб
1
# coding=utf-8
2
# Copyright 2024 The Google Research Authors.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
"""Evaluator object.
17

18
Computes all evaluation metrics.
19

20
Evaluation metrics include forecasting and inventory performance metrics.
21
"""
22

23
# allow capital letter names for dimensions to improve clarity (e.g. N, T, D)
24
# pylint: disable=invalid-name
25

26
import numpy as np
27
from scipy.stats import norm
28
import torch
29
from utils.eval_utils import get_ragged_mean
30
from utils.eval_utils import get_ragged_sum
31
from utils.eval_utils import get_ragged_var
32

33

34
class Evaluator(object):
35
  """Evaluator class. Handles differentiable computation of metrics."""
36

37
  def __init__(
38
      self, first_cutoff, scale01, device, target_dims, no_safety_stock
39
  ):
40
    self.first_cutoff = first_cutoff
41
    self.scale01 = scale01
42
    self.device = device
43
    self.target_dims = target_dims
44
    self.no_safety_stock = no_safety_stock
45

46
  def _extract_target_dims(self, batch):
47
    """Extracts the slices corresponding to the target of interest.
48

49
    Args:
50
      batch: the batch data dictionary
51

52
    Returns:
53
      the batch data dictionary, sliced to extract the target
54
    """
55
    new_batch = batch.copy()
56
    for k, v in batch.items():
57
      if k in [
58
          'x',
59
          'x_scale',
60
          'x_offset',
61
          'model_inputs',
62
          'model_targets',
63
          'eval_inputs',
64
          'eval_targets',
65
      ]:
66
        if len(v.shape) == 2:
67
          v = v[:, self.target_dims]
68
        elif len(v.shape) == 3:
69
          v = v[:, :, self.target_dims]
70
        elif len(v.shape) == 4:
71
          v = v[:, :, :, self.target_dims]
72
        else:
73
          raise NotImplementedError('Unexpected number of dims: ', v.shape)
74
      new_batch[k] = v
75
    return new_batch
76

77
  def _rescale(self, arr, x_scale, x_offset):
78
    """Scale the array back up to its original values.
79

80
    Args:
81
      arr: scaled array
82
      x_scale: scale
83
      x_offset: offset
84

85
    Returns:
86
      array in its original range of values
87
    """
88
    if not x_scale.shape:
89
      return (arr * x_scale) + x_offset
90
    shape = arr.shape
91
    assert shape[0] == x_scale.shape[0]
92

93
    # repeat scale and offset to match up with arr shape
94
    to_expand = len(shape) - len(x_scale.shape)
95
    for _ in range(to_expand):
96
      x_scale = x_scale.unsqueeze(-1)
97
      x_offset = x_offset.unsqueeze(-1)
98
    x_scale = x_scale.repeat(1, *shape[1:])
99
    x_offset = x_offset.repeat(1, *shape[1:])
100
    return (arr * x_scale) + x_offset
101

102
  def _get_lengths_from_time_mask(self, time_mask):
103
    forecast_horizon_lengths = (
104
        time_mask[:, :, :, :, -1].sum(dim=2).unsqueeze(2).type(torch.int64)
105
    )  # D is always last
106
    time_lengths = (
107
        (time_mask[:, :, 0, :, -1] > 0)
108
        .float()
109
        .sum(dim=1)
110
        .unsqueeze(1)
111
        .type(torch.int64)
112
    )
113
    return forecast_horizon_lengths, time_lengths
114

115
  def compute_mse(
      self,
      preds,
      unfolded_actual_imputed,
      forecast_horizon_lengths,
      time_lengths,
      series_mean=True,
  ):
    """Computes the length-aware mean squared error.

    The first `self.first_cutoff` timepoints are dropped inside this method,
    so callers should pass tensors and lengths that have NOT been cut yet.

    Args:
      preds: predictions tensor
      unfolded_actual_imputed: actual values tensor, unfolded to be the same
        shape as predictions, and imputed to avoid issues with autodiff
      forecast_horizon_lengths: lengths of each forecast horizon
      time_lengths: lengths of each series
      series_mean: whether to take the mean across series

    Returns:
      mean squared error (a scalar if series_mean, otherwise the result of the
      two ragged means)
    """
    start = self.first_cutoff

    residuals = preds - unfolded_actual_imputed
    # discard everything before the first cutoff (N x T x L x D input)
    sq_residuals = (residuals**2)[:, start:, :]
    horizon_lens = forecast_horizon_lengths[:, start:, :]
    series_lens = time_lengths - start

    # ragged mean over the forecasting horizon, then over time
    per_origin = get_ragged_mean(
        sq_residuals, lens=horizon_lens, axis=-2, device=self.device
    )
    per_series = get_ragged_mean(
        per_origin, lens=series_lens, axis=-2, device=self.device
    )

    # optionally collapse across series
    return per_series.mean() if series_mean else per_series
155

156
  def _get_std_e(
157
      self,
158
      preds,
159
      unfolded_actual_imputed,
160
      unfolded_time_mask,
161
      eps=1e-5,
162
  ):
163
    """Compute the standard deviation over previous forecast errors.
164

165
    Args:
166
      preds: predictions tensor (N x T x L x D)
167
      unfolded_actual_imputed: actual values tensor, unfolded to be the same
168
        shape as predictions, and imputed to avoid issues with autodiff
169
      unfolded_time_mask: times mask, the [:,:,:,t] slice corresponds to whether
170
        the corresponding timepoint has passed
171
      eps: small constant for stability
172

173
    Returns:
174
      tensor of standard deviations
175
    """
176
    N, T, _, _ = preds.shape
177

178
    squared_errs = (preds - unfolded_actual_imputed) ** 2
179
    squared_errs = squared_errs.unsqueeze(-1).repeat(1, 1, 1, 1, T)
180
    masked_errs = squared_errs * unfolded_time_mask
181

182
    # handle first cutoff
183
    masked_errs = masked_errs[:, :, :, :, self.first_cutoff :]
184
    mask = unfolded_time_mask[:, :, :, :, self.first_cutoff :]
185
    mask_denom_nonzero = mask.sum(2).sum(1)  # takes errors per timestep
186
    mask_denom_nonzero = (mask_denom_nonzero != 0).float()
187
    mask_denom_nonzero = mask.sum(2).sum(1) + (
188
        1 - mask_denom_nonzero
189
    )  # fills in a 1 wherever it's 0
190

191
    avg_per_time = masked_errs.sum(2).sum(1) / mask_denom_nonzero  # N x T
192
    avg_per_time = torch.concat(
193
        [torch.zeros((N, 1, 1)).to(self.device), avg_per_time[:, :, :-1]],
194
        axis=2,
195
    )  # start @ 0
196
    std_e = torch.sqrt(
197
        avg_per_time + eps
198
    )  # square root causes some problems if MSE is 0
199
    std_e = std_e.permute(0, 2, 1)
200
    return std_e
201

202
  def compute_forecasting_metrics(
      self,
      preds,
      actual_batch,
      eps=1,
      periodicity=12,
      series_mean=True,
      rolling_eval=False,
      min0=False,
  ):
    """Computes forecasting metrics.

    Args:
      preds: predicted values (N x T x L x D)
      actual_batch: batch with actual values
      eps: small constant for stability
      periodicity: number of timepoints in a period
      series_mean: whether to take mean across series
      rolling_eval: whether evaluation in performed in a roll-forward manner
      min0: whether to cut off predictions at 0 as the minimum (e.g. since
        negative demand is impossible); only applied when `self.scale01` is set

    Returns:
      dictionary of forecasting metrics with keys 'mse', 'mpe', 'smape' and
      'mase'

    Raises:
      NotImplementedError: if the (rescaled) actual values contain negatives
    """
    N, T, L, D = preds.shape

    x_scale = actual_batch['x_scale']
    x_offset = actual_batch['x_offset']

    if 'eval_targets' in actual_batch:  # dealing with windowed input
      x_imputed = actual_batch['x']  # should have 144 timepoints for m3
      unfolded_actual_imputed = actual_batch['eval_targets']
      forecast_horizon_lengths, time_lengths = self._get_lengths_from_time_mask(
          actual_batch['eval_target_times_mask']
      )
    else:
      x_imputed = actual_batch['x_imputed']
      unfolded_actual_imputed = actual_batch['unfolded_actual_imputed']
      forecast_horizon_lengths = actual_batch['forecast_horizon_lengths']
      time_lengths = actual_batch['time_lengths']

    # compute_mse applies the first cutoff itself, so keep the un-cut lengths
    # around for it; every other metric below uses the cut versions.
    full_forecast_horizon_lengths = forecast_horizon_lengths
    full_time_lengths = time_lengths

    forecast_horizon_lengths = forecast_horizon_lengths[
        :, self.first_cutoff :, :
    ]
    time_lengths = time_lengths - self.first_cutoff

    if self.scale01:
      preds = self._rescale(preds, x_scale, x_offset)
      x_imputed = self._rescale(x_imputed, x_scale, x_offset)
      unfolded_actual_imputed = self._rescale(
          unfolded_actual_imputed, x_scale, x_offset
      )
      if min0:
        preds = torch.nn.functional.relu(preds)
    if x_imputed.min() < 0 or unfolded_actual_imputed.min() < 0:
      raise NotImplementedError(
          'unexpected value in x_imputed or unfolded_actual_imputed'
      )

    test_actual = unfolded_actual_imputed[:, self.first_cutoff :, :]
    test_preds = preds[:, self.first_cutoff :, :]

    # MSE
    # Pass the un-cut tensors and lengths: compute_mse slices off the first
    # cutoff internally. Passing pre-cut inputs (as was done previously)
    # applied the cutoff twice, making MSE inconsistent with the other metrics.
    mse = self.compute_mse(
        preds,
        unfolded_actual_imputed,
        full_forecast_horizon_lengths,
        full_time_lengths,
        series_mean=series_mean,
    )

    # MPE
    mpe = get_ragged_mean(
        (test_actual - test_preds) / (test_actual + eps),
        lens=forecast_horizon_lengths,
        axis=-2,
        device=self.device,
    )
    mpe = get_ragged_mean(mpe, lens=time_lengths, axis=-2, device=self.device)
    if series_mean:
      mpe = mpe.mean()

    # sMAPE
    # NOTE(review): the denominator only uses |actual| (+ eps), not
    # |actual| + |pred| -- confirm this is the intended definition.
    smape = get_ragged_mean(
        (test_actual - test_preds).abs() * 2.0 / (test_actual.abs() + eps),
        lens=forecast_horizon_lengths,
        axis=-2,
        device=self.device,
    )
    smape = get_ragged_mean(
        smape, lens=time_lengths, axis=-2, device=self.device
    )

    if series_mean:
      smape = smape.mean()

    # MASE
    ae = (unfolded_actual_imputed - preds).abs()
    ae = ae[:, self.first_cutoff :, :]

    if 'eval_targets' in actual_batch:
      full_N, full_T, full_D = (
          x_imputed.shape
      )  # expect x_imputed 2nd dim to match original timescale so times correct
      # running sum of absolute seasonal differences ...
      scale = torch.zeros((full_N, full_T, full_D)).to(self.device)
      scale[:, periodicity:] = (
          x_imputed[:, periodicity:] - x_imputed[:, :-periodicity]
      ).abs()
      scale = torch.cumsum(scale, dim=1)

      # ... divided by the number of differences seen so far
      scale_ct = torch.zeros((full_N, full_T, full_D)).to(self.device)
      scale_ct[:, periodicity:] = (
          torch.arange(1, full_T - periodicity + 1)
          .unsqueeze(-1)
          .unsqueeze(0)
          .repeat(full_N, 1, full_D)
      )

      # the first `periodicity` entries are 0 / 0 = nan here
      scale = scale / scale_ct
      if rolling_eval:  # each sample is actually a decoding point
        num_start_ts, num_roll_ts, _, num_dims = ae.shape  # t1, t2, l, d

        # figure out scaling factor corresponding to each element of ae
        start_ts = (actual_batch['target_times'][:, 0, 0] - 1).type(torch.int64)
        scales_unrolled = torch.cat(
            [
                scale,
                torch.ones(num_start_ts, num_roll_ts - 1, num_dims).to(
                    self.device
                )
                * 1e18,
            ],
            axis=1,
        ).unfold(1, num_roll_ts, 1)
        scales_unrolled = scales_unrolled.permute(0, 1, 3, 2)
        scales_unrolled = torch.cat(
            [
                scales_unrolled,
                torch.ones(
                    scales_unrolled.shape[0],
                    1,
                    scales_unrolled.shape[2],
                    scales_unrolled.shape[3],
                ).to(self.device),
            ],
            axis=1,
        )

        start_ts = torch.clamp(start_ts, max=scales_unrolled.shape[1] - 1)
        scale = torch.gather(
            scales_unrolled,
            1,
            start_ts.unsqueeze(-2)
            .unsqueeze(-2)
            .repeat(1, 1, scales_unrolled.shape[2], 1),
        ).squeeze(1)
      else:
        first_cutoff = int(actual_batch['target_times'].min().item()) - 1
        if 'max_t_cutoff' in actual_batch:
          max_t_cutoff = actual_batch['max_t_cutoff']
          scale = scale[:, first_cutoff:max_t_cutoff]
        else:
          scale = scale[:, first_cutoff:]
    else:
      assert periodicity <= self.first_cutoff
      scale = torch.zeros((N, T)).to(self.device)
      scale[:, periodicity:] = (
          unfolded_actual_imputed[:, periodicity:, 0]
          - unfolded_actual_imputed[:, :-periodicity, 0]
      ).abs()
      scale = torch.cumsum(scale, dim=1)

      scale_ct = torch.zeros((N, T)).to(self.device)
      scale_ct[:, periodicity:] = torch.arange(1, T - periodicity + 1)

      scale = scale / scale_ct
      scale = scale[:, self.first_cutoff :]

    # pad with nans so the scale can be unfolded into windows of length L,
    # then neutralise all nans (padding and 0 / 0 entries) with a scale of 1
    nans = np.empty((N, L - 1, D))
    nans[:] = np.nan
    nans = torch.from_numpy(nans).float().to(self.device)
    scale = torch.cat([scale, nans], dim=1).unfold(1, L, 1).permute(0, 1, 3, 2)
    scale = torch.nan_to_num(scale, nan=1.0)

    if scale.shape[1] < ae.shape[1]:
      print(scale.shape, ae.shape)
      # pad on the evaluator's device; a CPU tensor here cannot be
      # concatenated with `scale` when running on an accelerator
      ones = torch.ones(
          ae.shape[0], ae.shape[1] - scale.shape[1], ae.shape[2], ae.shape[3]
      ).to(self.device)
      scale = torch.cat([scale, ones], dim=1)
    ase = ae / scale

    mase = get_ragged_mean(
        ase, lens=forecast_horizon_lengths, axis=-2, device=self.device
    )
    mase = get_ragged_mean(mase, lens=time_lengths, axis=-2, device=self.device)
    if series_mean:
      mase = mase.mean()

    forecasting_metrics = {
        'mse': mse,
        'mpe': mpe,
        'smape': smape,
        'mase': mase,
    }
    return forecasting_metrics
409

410
  def _get_lagged(self, matrix, lag=1, same_size=True):
411
    N, _, D = matrix.shape  # N x T x D
412
    pad = torch.zeros((N, lag, D)).to(self.device)
413
    lagged = torch.concat([pad, matrix], axis=1)
414
    if same_size:
415
      lagged = lagged[:, :-lag, :]
416
    return lagged
417

418
  def compute_inventory_metrics(
      self,
      preds,
      actual_batch,
      target_service_level=0.95,
      unit_holding_cost=1,
      unit_var_o_cost=1.0 / 100000.0,
      unit_stockout_cost=1,
      series_mean=True,
      quantile_loss=None,
      naive_metrics=None,
      min0=False,
  ):
    """Computes inventory metrics.

    Args:
      preds: predicted values
      actual_batch: batch with actual values
      target_service_level: service level to use for safety stock calculation
      unit_holding_cost: cost per unit held
      unit_var_o_cost: cost per unit order variance
      unit_stockout_cost: cost per unit stockout
      series_mean:  whether to take mean across series
      quantile_loss: quantile loss objective, if relevant; when truthy no
        safety stock is added
      naive_metrics: baseline metrics; when truthy, additional metrics scaled
        by / relative to the baseline are included in the result
      min0: whether to cut off predictions at 0 as the minimum (e.g. since
        negative demand is impossible); only applied when `self.scale01` is set

    Returns:
      dictionary of inventory metrics. The 'inventory_values' entry is itself a
      dictionary holding the intermediate tensors (positions, levels, orders,
      ...) the metrics were computed from.
    """

    x_scale = actual_batch['x_scale']
    x_offset = actual_batch['x_offset']

    if 'eval_targets' in actual_batch:  # dealing with windowed input
      x_imputed = actual_batch['eval_targets'][:, :, 0]
      unfolded_actual_imputed = actual_batch['eval_targets']
      forecast_horizon_lengths, time_lengths = self._get_lengths_from_time_mask(
          actual_batch['eval_target_times_mask']
      )
      unfolded_time_mask = actual_batch['eval_target_times_mask']
    else:
      x_imputed = actual_batch['x_imputed']
      unfolded_actual_imputed = actual_batch['unfolded_actual_imputed']
      time_lengths = actual_batch['time_lengths']
      forecast_horizon_lengths = actual_batch['forecast_horizon_lengths']
      unfolded_time_mask = actual_batch['unfolded_time_mask']

    # Read the lead time from whichever targets were selected above. Reading
    # actual_batch['eval_targets'] unconditionally raised KeyError for
    # non-windowed batches, making the else branch unreachable.
    _, _, lead_time, target_D = unfolded_actual_imputed.shape  # N, T, L, D

    time_lengths = time_lengths - self.first_cutoff
    if self.scale01:
      preds = self._rescale(preds, x_scale, x_offset)
      unfolded_actual_imputed = self._rescale(
          unfolded_actual_imputed, x_scale, x_offset
      )
      x_imputed = self._rescale(x_imputed, x_scale, x_offset)
      if min0:
        preds = torch.nn.functional.relu(preds)

    # zero out masked horizon steps, then sum forecasts over the lead time
    preds = preds * unfolded_time_mask[:, :, :, :, -1]
    lead_forecasts = preds.sum(axis=2)  # N x T
    lead_forecasts = lead_forecasts[:, self.first_cutoff :]
    if quantile_loss or self.no_safety_stock:
      safety_stocks = torch.zeros(lead_forecasts.shape).to(self.device)
    else:
      std_e = self._get_std_e(
          preds,
          unfolded_actual_imputed,
          unfolded_time_mask,
          eps=1e-5,
      )  # N x T
      std_e = std_e * lead_time  # approximate lead time std_e
      safety_const = norm.ppf(target_service_level)
      safety_stocks = safety_const * std_e  # N x T

    inventory_positions = (
        self._get_lagged(lead_forecasts)
        + self._get_lagged(safety_stocks)
        - x_imputed[:, self.first_cutoff :]
    )

    orders = lead_forecasts + safety_stocks - inventory_positions

    # demand over the most recent `lead_time` steps, zero-padded at the start
    recent_demand = (
        self._get_lagged(x_imputed, lag=lead_time - 1, same_size=False)
        .unfold(1, lead_time, 1)
        .permute(0, 1, 3, 2)
    )
    # works because there's at least lead time worth of real obs
    # Explicit end index instead of `: -(lead_time - 1)`, which is `:0` (an
    # empty slice) when lead_time == 1.
    num_shifted = max(forecast_horizon_lengths.shape[1] - (lead_time - 1), 0)
    recent_horizon_lengths = torch.cat(
        [
            torch.ones(x_imputed.shape[0], lead_time - 1, 1, target_D).to(
                self.device
            )
            * lead_time,
            forecast_horizon_lengths[:, :num_shifted, :, :],
        ],
        axis=1,
    ).type(torch.int64)
    recent_demand = get_ragged_sum(
        recent_demand, recent_horizon_lengths, device=self.device, axis=2
    )

    net_inventory_levels = (
        self._get_lagged(lead_forecasts, lag=lead_time)
        + self._get_lagged(safety_stocks, lag=lead_time)
        - recent_demand
    )

    work_in_progress = inventory_positions - net_inventory_levels

    # holding cost: positive part of the net inventory level
    holding_cost = (
        torch.nn.functional.relu(net_inventory_levels) * unit_holding_cost
    )
    holding_cost = get_ragged_mean(
        holding_cost, time_lengths, device=self.device, axis=1
    )
    if series_mean:
      holding_cost = holding_cost.mean()

    # smooth (softplus) variant of the holding cost
    soft_holding_cost = (
        torch.nn.functional.softplus(net_inventory_levels) * unit_holding_cost
    )
    soft_holding_cost = get_ragged_mean(
        soft_holding_cost, time_lengths, device=self.device, axis=1
    )
    if series_mean:
      soft_holding_cost = soft_holding_cost.mean()

    var_o = get_ragged_var(
        orders,
        torch.maximum(time_lengths, torch.Tensor([0]).to(self.device)).type(
            torch.int64
        ),
        device=self.device,
        axis=1,
    )  # avg of variance of orders for each series
    if series_mean:
      var_o = var_o.mean()

    var_o_cost = var_o * unit_var_o_cost

    # proportion of orders that are negative
    prop_neg_orders = get_ragged_mean(
        (orders < 0).float(), time_lengths, device=self.device, axis=1
    )
    if series_mean:
      prop_neg_orders = prop_neg_orders.mean()

    # how often stockout occurs
    achieved_service_level = get_ragged_mean(
        (net_inventory_levels >= 0).float(),
        time_lengths,
        device=self.device,
        axis=1,
    )
    if series_mean:
      achieved_service_level = achieved_service_level.mean()

    # smooth (sigmoid) variant of the achieved service level
    soft_alpha = torch.sigmoid(net_inventory_levels * 1e2)
    soft_alpha = get_ragged_mean(
        soft_alpha, time_lengths, device=self.device, axis=1
    )
    if series_mean:
      soft_alpha = soft_alpha.mean()

    # stockout cost
    stockout_cost = (
        torch.nn.functional.relu(-net_inventory_levels) * unit_stockout_cost
    )
    stockout_cost = get_ragged_mean(
        stockout_cost, time_lengths, device=self.device, axis=1
    )
    if series_mean:
      stockout_cost = stockout_cost.mean()

    # compute rms
    rms = torch.sqrt(
        (
            holding_cost**2
            + var_o**2
            + (1.0 / (achieved_service_level + 1e-5)) ** 2
        )
        / 3.0
    )

    # cost
    total_cost = holding_cost + stockout_cost + var_o_cost

    inventory_values = {
        'inventory_positions': inventory_positions,
        'net_inventory_levels': net_inventory_levels,
        'work_in_progress': work_in_progress,
        'safety_stocks': safety_stocks,
        'orders': orders,
        'lead_forecasts': lead_forecasts,
        'unfolded_actual_imputed': unfolded_actual_imputed,
        'unfolded_time_mask': unfolded_time_mask,
        'time_lengths': time_lengths,
        'demand': x_imputed,
    }

    inventory_metrics = {
        'holding_cost': holding_cost,
        'soft_holding_cost': soft_holding_cost,
        'var_o': var_o,
        'var_o_cost': var_o_cost,
        'prop_neg_orders': prop_neg_orders,
        'achieved_service_level': achieved_service_level,
        'soft_achieved_service_level': soft_alpha,
        'stockout_cost': stockout_cost,
        'rms': rms,
        'total_cost': total_cost,
        'inventory_values': inventory_values,
    }

    # compute scaled_rms
    if naive_metrics:
      naive_holding_cost = naive_metrics['holding_cost']
      naive_var_o = naive_metrics['var_o']
      naive_service_level = naive_metrics['achieved_service_level']
      naive_stockout_cost = naive_metrics['stockout_cost']

      scaled_holding_cost = holding_cost / (naive_holding_cost + 1)
      scaled_var_o = var_o / (naive_var_o + 1)
      scaled_achieved_service_level = achieved_service_level / (
          naive_service_level + 0.1
      )
      scaled_rms = torch.sqrt(
          (
              scaled_holding_cost**2
              + scaled_var_o**2
              + (1.0 / (scaled_achieved_service_level + 0.1)) ** 2
          )
          / 3.0
      )

      # relative differences w.r.t. the baseline (service level is compared
      # through its damped inverse, so lower is better for all of them)
      rel_holding_cost = (holding_cost - naive_holding_cost) / (
          naive_holding_cost + 1
      )
      rel_var_o = (var_o - naive_var_o) / (naive_var_o + 1)
      rel_achieved_service_level = (
          (1.0 / (achieved_service_level + 0.1))
          - (1.0 / (naive_service_level + 0.1))
      ) / (1.0 / (naive_service_level + 0.1))
      rel_stockout_cost = (stockout_cost - naive_stockout_cost) / (
          naive_stockout_cost + 1
      )

      # squash each relative difference into (0, 1) once and reuse it
      sig_holding = torch.sigmoid(rel_holding_cost)
      sig_var_o = torch.sigmoid(rel_var_o)
      sig_service = torch.sigmoid(rel_achieved_service_level)
      sig_stockout = torch.sigmoid(rel_stockout_cost)

      rel_rms_avg = (sig_holding + sig_var_o + sig_service) / 3.0
      rel_rms_2 = (sig_holding**2) + (sig_var_o**2) + (sig_service**2)
      rel_rms_3 = (sig_holding**3) + (sig_var_o**3) + (sig_service**3)
      rel_rms_5 = (sig_holding**5) + (sig_var_o**5) + (sig_service**5)
      rel_rms_logsumexp = torch.logsumexp(
          torch.stack([sig_holding, sig_var_o, sig_service], dim=0),
          dim=0,
      )

      # same combinations, with stockout cost in place of service level
      rel_rms_stockout_2 = (
          (sig_holding**2) + (sig_var_o**2) + (sig_stockout**2)
      )
      rel_rms_stockout_3 = (
          (sig_holding**3) + (sig_var_o**3) + (sig_stockout**3)
      )
      rel_rms_stockout_5 = (
          (sig_holding**5) + (sig_var_o**5) + (sig_stockout**5)
      )

      inventory_metrics['scaled_rms'] = scaled_rms
      inventory_metrics['rel_rms_avg'] = rel_rms_avg
      inventory_metrics['rel_rms_2'] = rel_rms_2
      inventory_metrics['rel_rms_3'] = rel_rms_3
      inventory_metrics['rel_rms_5'] = rel_rms_5
      inventory_metrics['rel_rms_logsumexp'] = rel_rms_logsumexp
      inventory_metrics['rel_rms_stockout_2'] = rel_rms_stockout_2
      inventory_metrics['rel_rms_stockout_3'] = rel_rms_stockout_3
      inventory_metrics['rel_rms_stockout_5'] = rel_rms_stockout_5

    return inventory_metrics
723

724
  def compute_all_metrics(
      self,
      preds,
      actual_batch,
      target_service_level,
      unit_holding_cost,
      unit_stockout_cost,
      unit_var_o_cost,
      series_mean,
      quantile_loss,
      naive_model,
      scale_by_naive_model=False,
      rolling_eval=False,
      min0=False,
  ):
    """Given predictions, computes all metrics of interest.

    Inventory and forecasting metrics are computed per series for both the
    naive baseline and the given predictions, optionally divided by the
    baseline's value, and only then averaged across series.

    Args:
      preds: predicted values
      actual_batch: batch with actual values
      target_service_level: service level to use for safety stock calculation
      unit_holding_cost: cost per unit held
      unit_stockout_cost: cost per unit stockout
      unit_var_o_cost: cost per unit order variance
      series_mean:  whether to take mean across series
      quantile_loss: quantile loss objective, if relevant
      naive_model: baseline model, called as
        `naive_model(actual_batch, in_eval=True)`
      scale_by_naive_model: whether to scale performance by baseline model
      rolling_eval: whether evaluation is roll-forward
      min0: whether to cut off predictions at 0 as the minimum (e.g. since
        negative demand is impossible)

    Returns:
      dictionary mapping metric names to values: all inventory metrics
      (including the raw 'inventory_values' dictionary), all forecasting
      metrics, and 'quantile_loss' when a quantile loss is given
    """
    # NOTE(review): min0 is only used for the quantile loss below; it is not
    # forwarded to the inventory / forecasting metric computations -- confirm
    # that this is intended.
    actual_batch = self._extract_target_dims(actual_batch)

    all_metrics = {}

    _, T, _, _ = preds.shape  # N x T x L x D

    # metrics are kept per series until after the optional naive scaling
    immediate_series_mean = False

    # compute naive model metrics
    naive_all_metrics = {}
    with torch.no_grad():
      naive_preds = naive_model(actual_batch, in_eval=True)
      naive_preds = naive_preds[:, :T, :, :]
      naive_inventory_metrics = self.compute_inventory_metrics(
          naive_preds,
          actual_batch,
          target_service_level=target_service_level,
          unit_holding_cost=unit_holding_cost,
          unit_stockout_cost=unit_stockout_cost,
          unit_var_o_cost=unit_var_o_cost,
          series_mean=immediate_series_mean,
      )
      naive_forecasting_metrics = self.compute_forecasting_metrics(
          naive_preds,
          actual_batch,
          series_mean=immediate_series_mean,
          rolling_eval=rolling_eval,
      )
      naive_all_metrics.update(naive_inventory_metrics)
      naive_all_metrics.update(naive_forecasting_metrics)

    # compute inventory metrics
    inventory_metrics = self.compute_inventory_metrics(
        preds,
        actual_batch,
        target_service_level=target_service_level,
        unit_holding_cost=unit_holding_cost,
        unit_stockout_cost=unit_stockout_cost,
        unit_var_o_cost=unit_var_o_cost,
        series_mean=immediate_series_mean,
        quantile_loss=quantile_loss,
        naive_metrics=naive_all_metrics,
    )

    for metric_name, metric_val in inventory_metrics.items():
      if metric_name == 'inventory_values':
        all_metrics[metric_name] = metric_val
        continue
      # Metrics derived from the baseline (scaled_rms, rel_rms_*) have no
      # counterpart in the naive metrics and are already relative to it, so
      # they are left unscaled instead of raising KeyError.
      if scale_by_naive_model and metric_name in naive_all_metrics:
        metric_val = metric_val / (naive_all_metrics[metric_name] + 1e-5)
      if (not immediate_series_mean) and series_mean:
        metric_val = metric_val.mean()
      all_metrics[metric_name] = metric_val

    # compute forecasting metrics
    forecasting_metrics = self.compute_forecasting_metrics(
        preds,
        actual_batch,
        series_mean=immediate_series_mean,
        rolling_eval=rolling_eval,
    )

    for metric_name, metric_val in forecasting_metrics.items():
      if scale_by_naive_model and metric_name != 'mpe':
        metric_val = metric_val / (naive_all_metrics[metric_name] + 1e-5)
      if (not immediate_series_mean) and series_mean:
        metric_val = metric_val.mean()
      all_metrics[metric_name] = metric_val

    # add quantile loss
    if quantile_loss:
      # default to the raw targets so `targets` is defined even when no
      # rescaling is performed (previously a NameError without scale01)
      targets = actual_batch['eval_targets']
      if self.scale01:
        preds = self._rescale(
            preds, actual_batch['x_scale'], actual_batch['x_offset']
        )
        targets = self._rescale(
            targets,
            actual_batch['x_scale'],
            actual_batch['x_offset'],
        )
        if min0:
          preds = torch.nn.functional.relu(preds)
      forecast_horizon_lengths, time_lengths = self._get_lengths_from_time_mask(
          actual_batch['eval_target_times_mask']
      )
      qloss = quantile_loss(
          preds, targets, forecast_horizon_lengths, time_lengths
      )
      if series_mean:
        qloss = qloss.mean()
      all_metrics['quantile_loss'] = qloss

    return all_metrics
851

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.