# coding=utf-8
# Copyright 2023 IBM and HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
""" Testing suite for the PyTorch PatchTSMixer model. """

import inspect
import itertools
import random
import tempfile
import unittest
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
from huggingface_hub import hf_hub_download
from parameterized import parameterized

from transformers import is_torch_available
from transformers.models.auto import get_values
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device

from ...test_configuration_common import ConfigTester
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
from ...test_pipeline_mixin import PipelineTesterMixin


TOLERANCE = 1e-4

if is_torch_available():
    import torch

    from transformers import (
        MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING,
        MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING,
        PatchTSMixerConfig,
        PatchTSMixerForPrediction,
        PatchTSMixerForPretraining,
        PatchTSMixerForRegression,
        PatchTSMixerForTimeSeriesClassification,
        PatchTSMixerModel,
    )
    from transformers.models.patchtsmixer.modeling_patchtsmixer import (
        PatchTSMixerEncoder,
        PatchTSMixerForPredictionHead,
        PatchTSMixerForPredictionOutput,
        PatchTSMixerForRegressionOutput,
        PatchTSMixerForTimeSeriesClassificationOutput,
        PatchTSMixerLinearHead,
        PatchTSMixerPretrainHead,
    )


@require_torch
class PatchTSMixerModelTester:
    def __init__(
        self,
        context_length: int = 32,
        patch_length: int = 8,
        num_input_channels: int = 3,
        patch_stride: int = 8,
        # d_model: int = 128,
        hidden_size: int = 8,
        # num_layers: int = 8,
        num_hidden_layers: int = 2,
        expansion_factor: int = 2,
        dropout: float = 0.5,
        mode: str = "common_channel",
        gated_attn: bool = True,
        norm_mlp="LayerNorm",
        swin_hier: int = 0,
        # masking related
        mask_type: str = "forecast",
        random_mask_ratio=0.5,
        mask_patches: list = [2, 3],
        forecast_mask_ratios: list = [1, 1],
        mask_value=0,
        masked_loss: bool = False,
        mask_mode: str = "mask_before_encoder",
        channel_consistent_masking: bool = True,
        scaling: Optional[Union[str, bool]] = "std",
        # Head related
        head_dropout: float = 0.2,
        # forecast related
        prediction_length: int = 16,
        out_channels: Optional[int] = None,
        # Classification/regression related
        # num_labels: int = 3,
        num_targets: int = 3,
        output_range: Optional[list] = None,
        head_aggregation: Optional[str] = None,
        # Trainer related
        batch_size=13,
        is_training=True,
        seed_number=42,
        post_init=True,
        num_parallel_samples=4,
    ):
        self.num_input_channels = num_input_channels
        self.context_length = context_length
        self.patch_length = patch_length
        self.patch_stride = patch_stride
        # self.d_model = d_model
        self.hidden_size = hidden_size
        self.expansion_factor = expansion_factor
        # self.num_layers = num_layers
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.mode = mode
        self.gated_attn = gated_attn
        self.norm_mlp = norm_mlp
        self.swin_hier = swin_hier
        self.scaling = scaling
        self.head_dropout = head_dropout
        # masking related
        self.mask_type = mask_type
        self.random_mask_ratio = random_mask_ratio
        self.mask_patches = mask_patches
        self.forecast_mask_ratios = forecast_mask_ratios
        self.mask_value = mask_value
        self.channel_consistent_masking = channel_consistent_masking
        self.mask_mode = mask_mode
        self.masked_loss = masked_loss
        # patching related
        self.patch_last = True
        # forecast related
        self.prediction_length = prediction_length
        self.out_channels = out_channels
        # classification/regression related
        # self.num_labels = num_labels
        self.num_targets = num_targets
        self.output_range = output_range
        self.head_aggregation = head_aggregation
        # Trainer related
        self.batch_size = batch_size
        self.is_training = is_training
        self.seed_number = seed_number
        self.post_init = post_init
        self.num_parallel_samples = num_parallel_samples

    def get_config(self):
        config_ = PatchTSMixerConfig(
            num_input_channels=self.num_input_channels,
            context_length=self.context_length,
            patch_length=self.patch_length,
            patch_stride=self.patch_stride,
            # d_model = self.d_model,
            d_model=self.hidden_size,
            expansion_factor=self.expansion_factor,
            # num_layers = self.num_layers,
            num_layers=self.num_hidden_layers,
            dropout=self.dropout,
            mode=self.mode,
            gated_attn=self.gated_attn,
            norm_mlp=self.norm_mlp,
            swin_hier=self.swin_hier,
            scaling=self.scaling,
            head_dropout=self.head_dropout,
            mask_type=self.mask_type,
            random_mask_ratio=self.random_mask_ratio,
            mask_patches=self.mask_patches,
            forecast_mask_ratios=self.forecast_mask_ratios,
            mask_value=self.mask_value,
            channel_consistent_masking=self.channel_consistent_masking,
            mask_mode=self.mask_mode,
            masked_loss=self.masked_loss,
            prediction_length=self.prediction_length,
            out_channels=self.out_channels,
            # num_labels=self.num_labels,
            num_targets=self.num_targets,
            output_range=self.output_range,
            head_aggregation=self.head_aggregation,
            post_init=self.post_init,
        )
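        # The config derives `num_patches` from the patching settings, mirroring the
        # formula used later in this file:
        # (max(context_length, patch_length) - patch_length) // patch_stride + 1,
        # which gives (32 - 8) // 8 + 1 = 4 with the defaults above.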
        self.num_patches = config_.num_patches
        return config_

    def prepare_patchtsmixer_inputs_dict(self, config):
        _past_length = config.context_length
        # bs, n_vars, num_patch, patch_length

        # [bs x context_length x n_vars]
        past_values = floats_tensor([self.batch_size, _past_length, self.num_input_channels])

        inputs_dict = {
            "past_values": past_values,
        }
        return inputs_dict

    def prepare_config_and_inputs(self):
        config = self.get_config()
        inputs_dict = self.prepare_patchtsmixer_inputs_dict(config)
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict


@require_torch
class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    all_model_classes = (
        (
            PatchTSMixerModel,
            PatchTSMixerForPrediction,
            PatchTSMixerForPretraining,
            PatchTSMixerForTimeSeriesClassification,
            PatchTSMixerForRegression,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (PatchTSMixerForPrediction, PatchTSMixerForPretraining) if is_torch_available() else ()
    )
    pipeline_model_mapping = {"feature-extraction": PatchTSMixerModel} if is_torch_available() else {}
    is_encoder_decoder = False
    test_pruning = False
    test_head_masking = False
    test_missing_keys = False
    test_torchscript = False
    test_inputs_embeds = False
    test_model_common_attributes = False

    test_resize_embeddings = True
    test_resize_position_embeddings = False
    test_mismatched_shapes = True
    test_model_parallel = False
    has_attentions = False

    def setUp(self):
        self.model_tester = PatchTSMixerModelTester()
        self.config_tester = ConfigTester(
            self,
            config_class=PatchTSMixerConfig,
            has_text_modality=False,
            prediction_length=self.model_tester.prediction_length,
            common_properties=["hidden_size", "expansion_factor", "num_hidden_layers"],
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        if model_class == PatchTSMixerForPrediction:
            rng = random.Random(self.model_tester.seed_number)
            labels = floats_tensor(
                [
                    self.model_tester.batch_size,
                    self.model_tester.prediction_length,
                    self.model_tester.num_input_channels,
                ],
                rng=rng,
            )
            inputs_dict["future_values"] = labels
        elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
            rng = random.Random(self.model_tester.seed_number)
            labels = ids_tensor([self.model_tester.batch_size], self.model_tester.num_targets, rng=rng)
            inputs_dict["target_values"] = labels
        elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
            rng = random.Random(self.model_tester.seed_number)
            labels = floats_tensor([self.model_tester.batch_size, self.model_tester.num_targets], rng=rng)
            inputs_dict["target_values"] = labels

        inputs_dict["output_hidden_states"] = True
        return inputs_dict

    def test_save_load_strict(self):
        config, _ = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                model2, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_hidden_states_output(self):
        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester,
                "expected_num_hidden_layers",
                self.model_tester.num_hidden_layers,
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            expected_hidden_size = self.model_tester.hidden_size
            self.assertEqual(hidden_states[0].shape[-1], expected_hidden_size)

            num_patch = self.model_tester.num_patches
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [num_patch, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            check_hidden_states_output(inputs_dict, config, model_class)

    @unittest.skip("No token embeddings")
    def test_resize_tokens_embeddings(self):
        pass

    def test_model_outputs_equivalence(self):
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
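            # NaN is the only float value that compares unequal to itself, so `t != t` masks NaNs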
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs={}):
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                output_ = model(**dict_inputs, return_dict=True, **additional_kwargs)
                attributes_ = vars(output_)
                dict_output = tuple(attributes_.values())

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (List, Tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, Dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object),
                                set_nan_tensor_to_zero(dict_object),
                                atol=1e-5,
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}."
                                f" Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`:"
                                f" {torch.isinf(dict_object).any()}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)

            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class)
            tuple_inputs.update({"output_hidden_states": False})
            dict_inputs.update({"output_hidden_states": False})
            check_equivalence(model, tuple_inputs, dict_inputs)

            tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=True)
            tuple_inputs.update({"output_hidden_states": False})
            dict_inputs.update({"output_hidden_states": False})
            check_equivalence(model, tuple_inputs, dict_inputs)

    def test_model_main_input_name(self):
        model_signature = inspect.signature(getattr(PatchTSMixerModel, "forward"))
        # The main input is the name of the argument after `self`
        observed_main_input_name = list(model_signature.parameters.keys())[1]
        self.assertEqual(PatchTSMixerModel.main_input_name, observed_main_input_name)

    def test_forward_signature(self):
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            if model_class == PatchTSMixerForPretraining:
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "output_hidden_states",
                    "return_loss",
                ]
            elif model_class == PatchTSMixerModel:
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "output_hidden_states",
                ]
            elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
                MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
            ):
                expected_arg_names = [
                    "past_values",
                    "target_values",
                    "output_hidden_states",
                    "return_loss",
                ]
            else:
                # PatchTSMixerForPrediction
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "future_values",
                    "output_hidden_states",
                    "return_loss",
                ]

            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    @is_flaky()
    def test_retain_grad_hidden_states_attentions(self):
        super().test_retain_grad_hidden_states_attentions()


def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
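    """Fetch a pre-saved batch of model inputs from the hub dataset used by the integration tests."""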
    # TODO: Make repo public
    file = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
    batch = torch.load(file, map_location=torch_device)
    return batch


@require_torch
@slow
class PatchTSMixerModelIntegrationTests(unittest.TestCase):
    def test_pretrain_head(self):
        model = PatchTSMixerForPretraining.from_pretrained("ibm/patchtsmixer-etth1-pretrain").to(torch_device)
        batch = prepare_batch()

        torch.manual_seed(0)
        with torch.no_grad():
            output = model(past_values=batch["past_values"].to(torch_device)).prediction_outputs
        num_patch = (
            max(model.config.context_length, model.config.patch_length) - model.config.patch_length
        ) // model.config.patch_stride + 1
        expected_shape = torch.Size(
            [
                64,
                model.config.num_input_channels,
                num_patch,
                model.config.patch_length,
            ]
        )
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor([[[[-0.9106]], [[1.5326]], [[-0.8245]], [[0.7439]], [[-0.7830]], [[2.6256]], [[-0.6485]]]], device=torch_device)  # fmt: skip
        self.assertTrue(torch.allclose(output[0, :7, :1, :1], expected_slice, atol=TOLERANCE))

    def test_forecasting_head(self):
        model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-forecasting").to(torch_device)
        batch = prepare_batch(file="forecast_batch.pt")

        model.eval()
        torch.manual_seed(0)
        with torch.no_grad():
            output = model(
                past_values=batch["past_values"].to(torch_device),
                future_values=batch["future_values"].to(torch_device),
            ).prediction_outputs

        expected_shape = torch.Size([64, model.config.prediction_length, model.config.num_input_channels])
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [[0.2471, 0.5036, 0.3596, 0.5401, -0.0985, 0.3423, -0.8439]],
            device=torch_device,
        )
        self.assertTrue(torch.allclose(output[0, :1, :7], expected_slice, atol=TOLERANCE))

    def test_prediction_generation(self):
        model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-generate").to(torch_device)
        batch = prepare_batch(file="forecast_batch.pt")

        torch.manual_seed(0)
        model.eval()
        with torch.no_grad():
            outputs = model.generate(past_values=batch["past_values"].to(torch_device))
        expected_shape = torch.Size((64, 1, model.config.prediction_length, model.config.num_input_channels))

        self.assertEqual(outputs.sequences.shape, expected_shape)

        expected_slice = torch.tensor(
            [[0.4308, -0.4731, 1.3512, -0.1038, -0.4655, 1.1279, -0.7179]],
            device=torch_device,
        )
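        # `sequences` has shape [batch, num_parallel_samples, prediction_length, num_channels];
        # averaging over dim=1 below collapses the parallel samples into a point forecast.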

        mean_prediction = outputs.sequences.mean(dim=1)

        self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))


@require_torch
class PatchTSMixerFunctionalTests(unittest.TestCase):
    @classmethod
    def setUpClass(cls):
        """Setup method: Called once before test-cases execution"""
        cls.params = {}
        cls.params.update(
            context_length=32,
            patch_length=8,
            num_input_channels=3,
            patch_stride=8,
            d_model=4,
            expansion_factor=2,
            num_layers=3,
            dropout=0.2,
            mode="common_channel",  # common_channel, mix_channel
            gated_attn=True,
            norm_mlp="LayerNorm",
            mask_type="random",
            random_mask_ratio=0.5,
            mask_patches=[2, 3],
            forecast_mask_ratios=[1, 1],
            mask_value=0,
            masked_loss=True,
            channel_consistent_masking=True,
            head_dropout=0.2,
            prediction_length=64,
            out_channels=None,
            # num_labels=3,
            num_targets=3,
            output_range=None,
            head_aggregation=None,
            scaling="std",
            use_positional_encoding=False,
            positional_encoding="sincos",
            self_attn=False,
            self_attn_heads=1,
            num_parallel_samples=4,
        )

        cls.num_patches = (
            max(cls.params["context_length"], cls.params["patch_length"]) - cls.params["patch_length"]
        ) // cls.params["patch_stride"] + 1
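        # i.e. (32 - 8) // 8 + 1 = 4 non-overlapping patches with the parameters above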

        # batch_size = 32
        batch_size = 2

        cls.data = torch.rand(
            batch_size,
            cls.params["context_length"],
            cls.params["num_input_channels"],
        )

        cls.enc_data = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["patch_length"],
        )

        cls.enc_output = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["d_model"],
        )

        cls.flat_enc_output = torch.rand(
            batch_size,
            cls.num_patches,
            cls.params["d_model"],
        )

        cls.correct_pred_output = torch.rand(
            batch_size,
            cls.params["prediction_length"],
            cls.params["num_input_channels"],
        )
        cls.correct_regression_output = torch.rand(batch_size, cls.params["num_targets"])

        cls.correct_pretrain_output = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["patch_length"],
        )

        cls.correct_forecast_output = torch.rand(
            batch_size,
            cls.params["prediction_length"],
            cls.params["num_input_channels"],
        )

        cls.correct_sel_forecast_output = torch.rand(batch_size, cls.params["prediction_length"], 2)

        cls.correct_classification_output = torch.rand(
            batch_size,
            cls.params["num_targets"],
        )

        cls.correct_classification_classes = torch.randint(0, cls.params["num_targets"], (batch_size,))

    def test_patchtsmixer_encoder(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        enc = PatchTSMixerEncoder(config)
        output = enc(self.__class__.enc_data)
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)

    def test_patchmodel(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerModel(config)
        output = mdl(self.__class__.data)
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertEqual(output.patch_input.shape, self.__class__.enc_data.shape)

    def test_pretrainhead(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        head = PatchTSMixerPretrainHead(
            config=config,
        )
        output = head(self.__class__.enc_output)

        self.assertEqual(output.shape, self.__class__.correct_pretrain_output.shape)

    def test_pretrain_full(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForPretraining(config)
        output = mdl(self.__class__.data)
        self.assertEqual(
            output.prediction_outputs.shape,
            self.__class__.correct_pretrain_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < np.inf)

    def test_pretrain_full_with_return_dict(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForPretraining(config)
        output = mdl(self.__class__.data, return_dict=False)
        self.assertEqual(output[1].shape, self.__class__.correct_pretrain_output.shape)
        self.assertEqual(output[2].shape, self.__class__.enc_output.shape)
        self.assertTrue(output[0].item() < np.inf)

    def test_forecast_head(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        head = PatchTSMixerForPredictionHead(
            config=config,
        )
        # output = head(self.__class__.enc_output, raw_data = self.__class__.correct_pretrain_output)
        output = head(self.__class__.enc_output)

        self.assertEqual(output.shape, self.__class__.correct_forecast_output.shape)

    def check_module(
        self,
        task,
        params=None,
        output_hidden_states=True,
    ):
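        """Build the model for `task` from `params`, run a forward pass on `self.data`
        (plus the task-specific ground-truth argument, if any), and check the shapes of
        the predictions and hidden states as well as the finiteness of the loss."""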
        config = PatchTSMixerConfig(**params)
        if task == "forecast":
            mdl = PatchTSMixerForPrediction(config)
            target_input = self.__class__.correct_forecast_output
            if config.prediction_channel_indices is not None:
                target_output = self.__class__.correct_sel_forecast_output
            else:
                target_output = target_input
            ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1, -1)
            ground_truth_arg = "future_values"
            output_predictions_arg = "prediction_outputs"
        elif task == "classification":
            mdl = PatchTSMixerForTimeSeriesClassification(config)
            target_input = self.__class__.correct_classification_classes
            target_output = self.__class__.correct_classification_output
            ground_truth_arg = "target_values"
            output_predictions_arg = "prediction_outputs"
        elif task == "regression":
            mdl = PatchTSMixerForRegression(config)
            target_input = self.__class__.correct_regression_output
            target_output = self.__class__.correct_regression_output
            ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1)
            ground_truth_arg = "target_values"
            output_predictions_arg = "regression_outputs"
        elif task == "pretrain":
            mdl = PatchTSMixerForPretraining(config)
            target_input = None
            target_output = self.__class__.correct_pretrain_output
            ground_truth_arg = None
            output_predictions_arg = "prediction_outputs"
        else:
            raise ValueError(f"invalid task: {task}")

        enc_output = self.__class__.enc_output

        if target_input is None:
            output = mdl(self.__class__.data, output_hidden_states=output_hidden_states)
        else:
            output = mdl(
                self.__class__.data,
                **{
                    ground_truth_arg: target_input,
                    "output_hidden_states": output_hidden_states,
                },
            )

        prediction_outputs = getattr(output, output_predictions_arg)
        if isinstance(prediction_outputs, tuple):
            for t in prediction_outputs:
                self.assertEqual(t.shape, target_output.shape)
        else:
            self.assertEqual(prediction_outputs.shape, target_output.shape)

        self.assertEqual(output.last_hidden_state.shape, enc_output.shape)

        if output_hidden_states is True:
            self.assertEqual(len(output.hidden_states), params["num_layers"])
        else:
            self.assertEqual(output.hidden_states, None)

        self.assertTrue(output.loss.item() < np.inf)

        if config.loss == "nll" and task in ["forecast", "regression"]:
            samples = mdl.generate(self.__class__.data)
            self.assertEqual(samples.sequences.shape, ref_samples.shape)

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                [None, [0, 2]],
                ["mse", "nll"],
            )
        )
    )
    def test_forecast(self, mode, self_attn, scaling, gated_attn, prediction_channel_indices, loss):
        params = self.__class__.params.copy()
        params.update(
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            prediction_channel_indices=prediction_channel_indices,
            gated_attn=gated_attn,
            loss=loss,
        )

        self.check_module(task="forecast", params=params)

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                ["max_pool", "avg_pool"],
            )
        )
    )
    def test_classification(self, mode, self_attn, scaling, gated_attn, head_aggregation):
        params = self.__class__.params.copy()
        params.update(
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            head_aggregation=head_aggregation,
            gated_attn=gated_attn,
        )

        self.check_module(task="classification", params=params)

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                ["max_pool", "avg_pool"],
                ["mse", "nll"],
            )
        )
    )
    def test_regression(self, mode, self_attn, scaling, gated_attn, head_aggregation, loss):
        params = self.__class__.params.copy()
        params.update(
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            head_aggregation=head_aggregation,
            gated_attn=gated_attn,
            loss=loss,
        )

        self.check_module(task="regression", params=params)

    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],
                [True, False],
                [True, False, "mean", "std"],
                [True, False],
                ["random", "forecast"],
                [True, False],
                [True, False],
            )
        )
    )
    def test_pretrain(
        self,
        mode,
        self_attn,
        scaling,
        gated_attn,
        mask_type,
        masked_loss,
        channel_consistent_masking,
    ):
        params = self.__class__.params.copy()
        params.update(
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            gated_attn=gated_attn,
            mask_type=mask_type,
            masked_loss=masked_loss,
            channel_consistent_masking=channel_consistent_masking,
        )

        self.check_module(task="pretrain", params=params)

    def forecast_full_module(self, params=None, output_hidden_states=False, return_dict=None):
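        """Forward-pass checks for PatchTSMixerForPrediction, covering the tuple output
        path (`return_dict=False`) and `generate` when the loss is `nll`."""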
        config = PatchTSMixerConfig(**params)
        mdl = PatchTSMixerForPrediction(config)

        target_val = self.__class__.correct_forecast_output

        if config.prediction_channel_indices is not None:
            target_val = self.__class__.correct_sel_forecast_output

        enc_output = self.__class__.enc_output

        output = mdl(
            self.__class__.data,
            future_values=self.__class__.correct_forecast_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if isinstance(output, tuple):
            output = PatchTSMixerForPredictionOutput(*output)

        if config.loss == "mse":
            self.assertEqual(output.prediction_outputs.shape, target_val.shape)

        self.assertEqual(output.last_hidden_state.shape, enc_output.shape)

        if output_hidden_states is True:
            self.assertEqual(len(output.hidden_states), params["num_layers"])
        else:
            self.assertEqual(output.hidden_states, None)

        self.assertTrue(output.loss.item() < np.inf)

        if config.loss == "nll":
            samples = mdl.generate(self.__class__.data)
            ref_samples = target_val.unsqueeze(1).expand(-1, params["num_parallel_samples"], -1, -1)
            self.assertEqual(samples.sequences.shape, ref_samples.shape)

    def test_forecast_full(self):
        self.check_module(task="forecast", params=self.__class__.params, output_hidden_states=True)
        # self.forecast_full_module(self.__class__.params, output_hidden_states = True)

    def test_forecast_full_2(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
        )
        self.forecast_full_module(params, output_hidden_states=True)

    def test_forecast_full_2_with_return_dict(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
        )
        self.forecast_full_module(params, output_hidden_states=True, return_dict=False)

    def test_forecast_full_3(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
        )
        self.forecast_full_module(params, output_hidden_states=True)

    def test_forecast_full_5(self):
        params = self.__class__.params.copy()
        params.update(
            self_attn=True,
            use_positional_encoding=True,
            positional_encoding="sincos",
        )
        self.forecast_full_module(params, output_hidden_states=True)

    def test_forecast_full_4(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            prediction_channel_indices=[0, 2],
        )
        self.forecast_full_module(params)

    def test_forecast_full_distributional(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            prediction_channel_indices=[0, 2],
            loss="nll",
            distribution_output="normal",
        )

        self.forecast_full_module(params)

    def test_forecast_full_distributional_2(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            prediction_channel_indices=[0, 2],
            loss="nll",
            # distribution_output = "normal",
        )
        self.forecast_full_module(params)

    def test_forecast_full_distributional_3(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            # prediction_channel_indices=[0, 2],
            loss="nll",
            distribution_output="normal",
        )
        self.forecast_full_module(params)

    def test_forecast_full_distributional_4(self):
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            # prediction_channel_indices=[0, 2],
            loss="nll",
            distribution_output="normal",
        )
        self.forecast_full_module(params)

    def test_classification_head(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        head = PatchTSMixerLinearHead(
            config=config,
        )
        # output = head(self.__class__.enc_output, raw_data = self.__class__.correct_pretrain_output)
        output = head(self.__class__.enc_output)

        self.assertEqual(output.shape, self.__class__.correct_classification_output.shape)

    def test_classification_full(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForTimeSeriesClassification(config)
        output = mdl(
            self.__class__.data,
            target_values=self.__class__.correct_classification_classes,
        )
        self.assertEqual(
            output.prediction_outputs.shape,
            self.__class__.correct_classification_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < np.inf)

    def test_classification_full_with_return_dict(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForTimeSeriesClassification(config)
        output = mdl(
            self.__class__.data,
            target_values=self.__class__.correct_classification_classes,
            return_dict=False,
        )
        if isinstance(output, tuple):
            output = PatchTSMixerForTimeSeriesClassificationOutput(*output)
        self.assertEqual(
            output.prediction_outputs.shape,
            self.__class__.correct_classification_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < np.inf)

    def test_regression_head(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        head = PatchTSMixerLinearHead(
            config=config,
        )
        output = head(self.__class__.enc_output)
        self.assertEqual(output.shape, self.__class__.correct_regression_output.shape)

    def test_regression_full(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForRegression(config)
        output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
        self.assertEqual(
            output.regression_outputs.shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < np.inf)

    def test_regression_full_with_return_dict(self):
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForRegression(config)
        output = mdl(
            self.__class__.data,
            target_values=self.__class__.correct_regression_output,
            return_dict=False,
        )
        if isinstance(output, tuple):
            output = PatchTSMixerForRegressionOutput(*output)
        self.assertEqual(
            output.regression_outputs.shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < np.inf)

    def test_regression_full_distribute(self):
        params = self.__class__.params.copy()
        params.update(loss="nll", distribution_output="normal")
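        # With an `nll` loss the regression head parameterizes a distribution, so
        # `regression_outputs` is a tuple of per-target distribution parameters,
        # indexed as [0] and [1] in the shape checks below.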

        config = PatchTSMixerConfig(**params)

        mdl = PatchTSMixerForRegression(config)
        output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
        self.assertEqual(
            output.regression_outputs[0].shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(
            output.regression_outputs[1].shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < np.inf)

        if config.loss == "nll":
            samples = mdl.generate(self.__class__.data)
            ref_samples = self.__class__.correct_regression_output.unsqueeze(1).expand(
                -1, params["num_parallel_samples"], -1
            )
            self.assertEqual(samples.sequences.shape, ref_samples.shape)

    def test_regression_full_distribute_2(self):
        params = self.__class__.params.copy()
        params.update(loss="nll", distribution_output="student_t")

        config = PatchTSMixerConfig(**params)

        mdl = PatchTSMixerForRegression(config)
        output = mdl(self.__class__.data, target_values=self.__class__.correct_regression_output)
        self.assertEqual(
            output.regression_outputs[0].shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(
            output.regression_outputs[1].shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        self.assertTrue(output.loss.item() < np.inf)

        if config.loss == "nll":
            samples = mdl.generate(self.__class__.data)
            ref_samples = self.__class__.correct_regression_output.unsqueeze(1).expand(
                -1, params["num_parallel_samples"], -1
            )
            self.assertEqual(samples.sequences.shape, ref_samples.shape)