transformers

Fork
0
/
test_modeling_patchtsmixer.py
1115 lines · 41.6 KB
1
# coding=utf-8
2
# Copyright 2023 IBM and HuggingFace Inc. team. All rights reserved.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15
""" Testing suite for the PyTorch PatchTSMixer model. """
16

17
import inspect
18
import itertools
19
import random
20
import tempfile
21
import unittest
22
from typing import Dict, List, Optional, Tuple, Union
23

24
import numpy as np
25
from huggingface_hub import hf_hub_download
26
from parameterized import parameterized
27

28
from transformers import is_torch_available
29
from transformers.models.auto import get_values
30
from transformers.testing_utils import is_flaky, require_torch, slow, torch_device
31

32
from ...test_configuration_common import ConfigTester
33
from ...test_modeling_common import ModelTesterMixin, floats_tensor, ids_tensor
34
from ...test_pipeline_mixin import PipelineTesterMixin
35

36

37
# Absolute tolerance (`atol`) used by `torch.allclose` when the integration tests compare
# model outputs against their hard-coded expected slices.
TOLERANCE = 1e-4
38

39
if is_torch_available():
40
    import torch
41

42
    from transformers import (
43
        MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING,
44
        MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING,
45
        PatchTSMixerConfig,
46
        PatchTSMixerForPrediction,
47
        PatchTSMixerForPretraining,
48
        PatchTSMixerForRegression,
49
        PatchTSMixerForTimeSeriesClassification,
50
        PatchTSMixerModel,
51
    )
52
    from transformers.models.patchtsmixer.modeling_patchtsmixer import (
53
        PatchTSMixerEncoder,
54
        PatchTSMixerForPredictionHead,
55
        PatchTSMixerForPredictionOutput,
56
        PatchTSMixerForRegressionOutput,
57
        PatchTSMixerForTimeSeriesClassificationOutput,
58
        PatchTSMixerLinearHead,
59
        PatchTSMixerPretrainHead,
60
    )
61

62

63
@require_torch
class PatchTSMixerModelTester:
    """Builds a small `PatchTSMixerConfig` plus matching random inputs for the common test mixins.

    `hidden_size` and `num_hidden_layers` are the attribute names the common tests read;
    `get_config` maps them onto the config's `d_model` and `num_layers` arguments.
    """

    def __init__(
        self,
        context_length: int = 32,
        patch_length: int = 8,
        num_input_channels: int = 3,
        patch_stride: int = 8,
        hidden_size: int = 8,
        num_hidden_layers: int = 2,
        expansion_factor: int = 2,
        dropout: float = 0.5,
        mode: str = "common_channel",
        gated_attn: bool = True,
        norm_mlp="LayerNorm",
        swin_hier: int = 0,
        # masking related
        mask_type: str = "forecast",
        random_mask_ratio=0.5,
        mask_patches: Optional[list] = None,
        forecast_mask_ratios: Optional[list] = None,
        mask_value=0,
        masked_loss: bool = False,
        mask_mode: str = "mask_before_encoder",
        channel_consistent_masking: bool = True,
        scaling: Optional[Union[str, bool]] = "std",
        # Head related
        head_dropout: float = 0.2,
        # forecast related
        prediction_length: int = 16,
        out_channels: Optional[int] = None,
        # Classification/regression related
        num_targets: int = 3,
        output_range: Optional[list] = None,
        head_aggregation: Optional[str] = None,
        # Trainer related
        batch_size=13,
        is_training=True,
        seed_number=42,
        post_init=True,
        num_parallel_samples=4,
    ):
        """Store the hyper-parameters that `get_config` forwards to `PatchTSMixerConfig`.

        `mask_patches` defaults to `[2, 3]` and `forecast_mask_ratios` to `[1, 1]`. They are
        declared as `None` in the signature so that each tester gets its own list instead of
        sharing one mutable default object across instances.
        """
        self.num_input_channels = num_input_channels
        self.context_length = context_length
        self.patch_length = patch_length
        self.patch_stride = patch_stride
        self.hidden_size = hidden_size
        self.expansion_factor = expansion_factor
        self.num_hidden_layers = num_hidden_layers
        self.dropout = dropout
        self.mode = mode
        self.gated_attn = gated_attn
        self.norm_mlp = norm_mlp
        self.swin_hier = swin_hier
        self.scaling = scaling
        self.head_dropout = head_dropout
        # masking related
        self.mask_type = mask_type
        self.random_mask_ratio = random_mask_ratio
        self.mask_patches = [2, 3] if mask_patches is None else mask_patches
        self.forecast_mask_ratios = [1, 1] if forecast_mask_ratios is None else forecast_mask_ratios
        self.mask_value = mask_value
        self.channel_consistent_masking = channel_consistent_masking
        self.mask_mode = mask_mode
        self.masked_loss = masked_loss
        # patching related
        self.patch_last = True
        # forecast related
        self.prediction_length = prediction_length
        self.out_channels = out_channels
        # classification/regression related
        self.num_targets = num_targets
        self.output_range = output_range
        self.head_aggregation = head_aggregation
        # Trainer related
        self.batch_size = batch_size
        self.is_training = is_training
        self.seed_number = seed_number
        self.post_init = post_init
        self.num_parallel_samples = num_parallel_samples

    def get_config(self):
        """Build a `PatchTSMixerConfig` from the stored attributes.

        As a side effect, `self.num_patches` is set from the config so that tests can compare
        hidden-state shapes after calling this method.
        """
        config_ = PatchTSMixerConfig(
            num_input_channels=self.num_input_channels,
            context_length=self.context_length,
            patch_length=self.patch_length,
            patch_stride=self.patch_stride,
            d_model=self.hidden_size,
            expansion_factor=self.expansion_factor,
            num_layers=self.num_hidden_layers,
            dropout=self.dropout,
            mode=self.mode,
            gated_attn=self.gated_attn,
            norm_mlp=self.norm_mlp,
            swin_hier=self.swin_hier,
            scaling=self.scaling,
            head_dropout=self.head_dropout,
            mask_type=self.mask_type,
            random_mask_ratio=self.random_mask_ratio,
            mask_patches=self.mask_patches,
            forecast_mask_ratios=self.forecast_mask_ratios,
            mask_value=self.mask_value,
            channel_consistent_masking=self.channel_consistent_masking,
            mask_mode=self.mask_mode,
            masked_loss=self.masked_loss,
            prediction_length=self.prediction_length,
            out_channels=self.out_channels,
            num_targets=self.num_targets,
            output_range=self.output_range,
            head_aggregation=self.head_aggregation,
            post_init=self.post_init,
            num_parallel_samples=self.num_parallel_samples,
        )
        self.num_patches = config_.num_patches
        return config_

    def prepare_patchtsmixer_inputs_dict(self, config):
        """Return `{"past_values": tensor}` with shape (batch_size, context_length, num_input_channels)."""
        past_values = floats_tensor([self.batch_size, config.context_length, self.num_input_channels])
        return {"past_values": past_values}

    def prepare_config_and_inputs(self):
        """Return a fresh config together with a matching random inputs dict."""
        config = self.get_config()
        inputs_dict = self.prepare_patchtsmixer_inputs_dict(config)
        return config, inputs_dict

    def prepare_config_and_inputs_for_common(self):
        """Entry point used by `ModelTesterMixin`; identical to `prepare_config_and_inputs`."""
        config, inputs_dict = self.prepare_config_and_inputs()
        return config, inputs_dict
207

208

209
@require_torch
class PatchTSMixerModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
    """Common model/pipeline tests for every PatchTSMixer head, plus model-specific overrides."""

    all_model_classes = (
        (
            PatchTSMixerModel,
            PatchTSMixerForPrediction,
            PatchTSMixerForPretraining,
            PatchTSMixerForTimeSeriesClassification,
            PatchTSMixerForRegression,
        )
        if is_torch_available()
        else ()
    )
    all_generative_model_classes = (
        (PatchTSMixerForPrediction, PatchTSMixerForPretraining) if is_torch_available() else ()
    )
    pipeline_model_mapping = {"feature-extraction": PatchTSMixerModel} if is_torch_available() else {}
    is_encoder_decoder = False
    test_pruning = False
    test_head_masking = False
    test_missing_keys = False
    test_torchscript = False
    test_inputs_embeds = False
    test_model_common_attributes = False

    test_resize_embeddings = True
    test_resize_position_embeddings = False
    test_mismatched_shapes = True
    test_model_parallel = False
    has_attentions = False

    def setUp(self):
        self.model_tester = PatchTSMixerModelTester()
        self.config_tester = ConfigTester(
            self,
            config_class=PatchTSMixerConfig,
            has_text_modality=False,
            prediction_length=self.model_tester.prediction_length,
            common_properties=["hidden_size", "expansion_factor", "num_hidden_layers"],
        )

    def test_config(self):
        self.config_tester.run_common_tests()

    def _prepare_for_class(self, inputs_dict, model_class, return_labels=False):
        """Extend the common inputs with the label argument each head expects.

        - `PatchTSMixerForPrediction` gets float `future_values` of shape
          (batch_size, prediction_length, num_input_channels).
        - Classification heads get integer `target_values` in [0, num_targets).
        - Regression heads get float `target_values` of shape (batch_size, num_targets).

        Hidden states are always requested.
        """
        inputs_dict = super()._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)

        # A fresh RNG with the same seed makes the generated labels identical across calls
        rng = random.Random(self.model_tester.seed_number)
        if model_class == PatchTSMixerForPrediction:
            inputs_dict["future_values"] = floats_tensor(
                [
                    self.model_tester.batch_size,
                    self.model_tester.prediction_length,
                    self.model_tester.num_input_channels,
                ],
                rng=rng,
            )
        elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING):
            inputs_dict["target_values"] = ids_tensor(
                [self.model_tester.batch_size], self.model_tester.num_targets, rng=rng
            )
        elif model_class in get_values(MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING):
            inputs_dict["target_values"] = floats_tensor(
                [self.model_tester.batch_size, self.model_tester.num_targets], rng=rng
            )

        inputs_dict["output_hidden_states"] = True
        return inputs_dict

    def test_save_load_strict(self):
        """Reloading a saved model must not report any missing weights."""
        config, _ = self.model_tester.prepare_config_and_inputs()
        for model_class in self.all_model_classes:
            model = model_class(config)

            with tempfile.TemporaryDirectory() as tmpdirname:
                model.save_pretrained(tmpdirname)
                _, info = model_class.from_pretrained(tmpdirname, output_loading_info=True)
            self.assertEqual(info["missing_keys"], [])

    def test_hidden_states_output(self):
        """One hidden state per layer, each ending in (num_patches, hidden_size)."""

        def check_hidden_states_output(inputs_dict, config, model_class):
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            with torch.no_grad():
                outputs = model(**self._prepare_for_class(inputs_dict, model_class))

            hidden_states = outputs.encoder_hidden_states if config.is_encoder_decoder else outputs.hidden_states

            expected_num_layers = getattr(
                self.model_tester,
                "expected_num_hidden_layers",
                self.model_tester.num_hidden_layers,
            )
            self.assertEqual(len(hidden_states), expected_num_layers)

            expected_hidden_size = self.model_tester.hidden_size
            self.assertEqual(hidden_states[0].shape[-1], expected_hidden_size)

            # `num_patches` is set on the tester by `get_config`
            num_patch = self.model_tester.num_patches
            self.assertListEqual(
                list(hidden_states[0].shape[-2:]),
                [num_patch, self.model_tester.hidden_size],
            )

        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            check_hidden_states_output(inputs_dict, config, model_class)

    @unittest.skip("No tokens embeddings")
    def test_resize_tokens_embeddings(self):
        pass

    def test_model_outputs_equivalence(self):
        """`return_dict=False` and `return_dict=True` must produce numerically equal outputs."""
        config, inputs_dict = self.model_tester.prepare_config_and_inputs_for_common()

        def set_nan_tensor_to_zero(t):
            # NaN is the only value not equal to itself; zero it so allclose can compare
            t[t != t] = 0
            return t

        def check_equivalence(model, tuple_inputs, dict_inputs, additional_kwargs=None):
            additional_kwargs = additional_kwargs or {}
            with torch.no_grad():
                tuple_output = model(**tuple_inputs, return_dict=False, **additional_kwargs)
                output_ = model(**dict_inputs, return_dict=True, **additional_kwargs)
                # Flatten the dict-style output into the values of its instance attributes
                dict_output = tuple(vars(output_).values())

                def recursive_check(tuple_object, dict_object):
                    if isinstance(tuple_object, (list, tuple)):
                        for tuple_iterable_value, dict_iterable_value in zip(tuple_object, dict_object):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif isinstance(tuple_object, dict):
                        for tuple_iterable_value, dict_iterable_value in zip(
                            tuple_object.values(), dict_object.values()
                        ):
                            recursive_check(tuple_iterable_value, dict_iterable_value)
                    elif tuple_object is None:
                        return
                    else:
                        self.assertTrue(
                            torch.allclose(
                                set_nan_tensor_to_zero(tuple_object),
                                set_nan_tensor_to_zero(dict_object),
                                atol=1e-5,
                            ),
                            msg=(
                                "Tuple and dict output are not equal. Difference:"
                                f" {torch.max(torch.abs(tuple_object - dict_object))}. Tuple has `nan`:"
                                f" {torch.isnan(tuple_object).any()} and `inf`: {torch.isinf(tuple_object).any()}."
                                f" Dict has `nan`: {torch.isnan(dict_object).any()} and `inf`:"
                                f" {torch.isinf(dict_object).any()}."
                            ),
                        )

                recursive_check(tuple_output, dict_output)

        for model_class in self.all_model_classes:
            model = model_class(config)
            model.to(torch_device)
            model.eval()

            # Rounds, in order: defaults, with labels, hidden states off, labels + hidden states off
            for overrides, return_labels in itertools.product(
                ({}, {"output_hidden_states": False}), (False, True)
            ):
                tuple_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
                dict_inputs = self._prepare_for_class(inputs_dict, model_class, return_labels=return_labels)
                tuple_inputs.update(overrides)
                dict_inputs.update(overrides)
                check_equivalence(model, tuple_inputs, dict_inputs)

    def test_model_main_input_name(self):
        """`main_input_name` must match the first `forward` argument after `self`."""
        model_signature = inspect.signature(PatchTSMixerModel.forward)
        observed_main_input_name = list(model_signature.parameters.keys())[1]
        self.assertEqual(PatchTSMixerModel.main_input_name, observed_main_input_name)

    def test_forward_signature(self):
        """The leading `forward` parameters of each head appear in the expected order."""
        config, _ = self.model_tester.prepare_config_and_inputs_for_common()

        for model_class in self.all_model_classes:
            model = model_class(config)
            signature = inspect.signature(model.forward)
            # signature.parameters is an OrderedDict => so arg_names order is deterministic
            arg_names = [*signature.parameters.keys()]

            if model_class == PatchTSMixerForPretraining:
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "output_hidden_states",
                    "return_loss",
                ]
            elif model_class == PatchTSMixerModel:
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "output_hidden_states",
                ]
            elif model_class in get_values(MODEL_FOR_TIME_SERIES_CLASSIFICATION_MAPPING) or model_class in get_values(
                MODEL_FOR_TIME_SERIES_REGRESSION_MAPPING
            ):
                expected_arg_names = [
                    "past_values",
                    "target_values",
                    "output_hidden_states",
                    "return_loss",
                ]
            else:
                # PatchTSMixerForPrediction
                expected_arg_names = [
                    "past_values",
                    "observed_mask",
                    "future_values",
                    "output_hidden_states",
                    "return_loss",
                ]

            self.assertListEqual(arg_names[: len(expected_arg_names)], expected_arg_names)

    @is_flaky()
    def test_retain_grad_hidden_states_attentions(self):
        super().test_retain_grad_hidden_states_attentions()
450

451

452
def prepare_batch(repo_id="ibm/patchtsmixer-etth1-test-data", file="pretrain_batch.pt"):
    """Download a serialized test batch from a Hub dataset repo and load it onto `torch_device`.

    Args:
        repo_id: Hub dataset repository that holds the batch files.
        file: Name of the `.pt` file inside that repository.

    Returns:
        The deserialized object. Callers index it with keys such as "past_values" and
        "future_values".
    """
    # TODO: Make repo public
    local_path = hf_hub_download(repo_id=repo_id, filename=file, repo_type="dataset")
    # NOTE(security): `torch.load` unpickles the file and can execute arbitrary code; only load
    # from trusted repositories (consider `weights_only=True` where the installed torch supports it).
    return torch.load(local_path, map_location=torch_device)
457

458

459
@require_torch
@slow
class PatchTSMixerModelIntegrationTests(unittest.TestCase):
    """Slow checks of pretrained `ibm/patchtsmixer-etth1-*` checkpoints against reference values."""

    def test_pretrain_head(self):
        """Pretraining head output on the cached batch has the expected shape and leading values."""
        model = PatchTSMixerForPretraining.from_pretrained("ibm/patchtsmixer-etth1-pretrain").to(torch_device)
        model.eval()
        batch = prepare_batch()

        # Fix the torch RNG before the forward pass so the output is reproducible
        torch.manual_seed(0)
        with torch.no_grad():
            output = model(past_values=batch["past_values"].to(torch_device)).prediction_outputs
        num_patch = (
            max(model.config.context_length, model.config.patch_length) - model.config.patch_length
        ) // model.config.patch_stride + 1
        # The cached batch has 64 rows; output is (batch, channels, patches, patch_length)
        expected_shape = torch.Size(
            [
                64,
                model.config.num_input_channels,
                num_patch,
                model.config.patch_length,
            ]
        )
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor([[[[-0.9106]],[[1.5326]],[[-0.8245]],[[0.7439]],[[-0.7830]],[[2.6256]],[[-0.6485]],]],device=torch_device)  # fmt: skip
        self.assertTrue(torch.allclose(output[0, :7, :1, :1], expected_slice, atol=TOLERANCE))

    def test_forecasting_head(self):
        """Forecast head output has shape (batch, prediction_length, channels) and reference values."""
        model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-forecasting").to(torch_device)
        batch = prepare_batch(file="forecast_batch.pt")

        model.eval()
        torch.manual_seed(0)
        with torch.no_grad():
            output = model(
                past_values=batch["past_values"].to(torch_device),
                future_values=batch["future_values"].to(torch_device),
            ).prediction_outputs

        expected_shape = torch.Size([64, model.config.prediction_length, model.config.num_input_channels])
        self.assertEqual(output.shape, expected_shape)

        expected_slice = torch.tensor(
            [[0.2471, 0.5036, 0.3596, 0.5401, -0.0985, 0.3423, -0.8439]],
            device=torch_device,
        )
        self.assertTrue(torch.allclose(output[0, :1, :7], expected_slice, atol=TOLERANCE))

    def test_prediction_generation(self):
        """`generate` returns sequences of the expected shape whose sample mean matches the reference."""
        model = PatchTSMixerForPrediction.from_pretrained("ibm/patchtsmixer-etth1-generate").to(torch_device)
        batch = prepare_batch(file="forecast_batch.pt")

        torch.manual_seed(0)
        model.eval()
        with torch.no_grad():
            outputs = model.generate(past_values=batch["past_values"].to(torch_device))
        expected_shape = torch.Size((64, 1, model.config.prediction_length, model.config.num_input_channels))

        self.assertEqual(outputs.sequences.shape, expected_shape)

        expected_slice = torch.tensor(
            [[0.4308, -0.4731, 1.3512, -0.1038, -0.4655, 1.1279, -0.7179]],
            device=torch_device,
        )

        # Average over dim 1 (the size-1 samples axis in expected_shape)
        mean_prediction = outputs.sequences.mean(dim=1)

        self.assertTrue(torch.allclose(mean_prediction[0, -1:], expected_slice, atol=TOLERANCE))
527

528

529
@require_torch
530
class PatchTSMixerFunctionalTests(unittest.TestCase):
531
    @classmethod
    def setUpClass(cls):
        """Build shared fixtures once before the test cases of this class run.

        `cls.params` holds the `PatchTSMixerConfig` keyword arguments shared by the tests. The
        random tensors below are either model inputs (`data`, `enc_data`) or tensors whose
        shapes the outputs are compared against (some are also passed as targets).
        """
        cls.params = dict(
            context_length=32,
            patch_length=8,
            num_input_channels=3,
            patch_stride=8,
            d_model=4,
            expansion_factor=2,
            num_layers=3,
            dropout=0.2,
            mode="common_channel",  # "common_channel" or "mix_channel"
            gated_attn=True,
            norm_mlp="LayerNorm",
            mask_type="random",
            random_mask_ratio=0.5,
            mask_patches=[2, 3],
            forecast_mask_ratios=[1, 1],
            mask_value=0,
            masked_loss=True,
            channel_consistent_masking=True,
            head_dropout=0.2,
            prediction_length=64,
            out_channels=None,
            num_targets=3,
            output_range=None,
            head_aggregation=None,
            scaling="std",
            use_positional_encoding=False,
            positional_encoding="sincos",
            self_attn=False,
            self_attn_heads=1,
            num_parallel_samples=4,
        )

        # Number of windows of length `patch_length` taken every `patch_stride` steps
        # over `context_length` (at least one window).
        cls.num_patches = (
            max(cls.params["context_length"], cls.params["patch_length"]) - cls.params["patch_length"]
        ) // cls.params["patch_stride"] + 1

        batch_size = 2

        # Raw series input: (batch, context_length, channels)
        cls.data = torch.rand(
            batch_size,
            cls.params["context_length"],
            cls.params["num_input_channels"],
        )

        # Patched input: (batch, channels, num_patches, patch_length)
        cls.enc_data = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["patch_length"],
        )

        # Encoder output: (batch, channels, num_patches, d_model)
        cls.enc_output = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["d_model"],
        )

        cls.flat_enc_output = torch.rand(
            batch_size,
            cls.num_patches,
            cls.params["d_model"],
        )

        cls.correct_pred_output = torch.rand(
            batch_size,
            cls.params["prediction_length"],
            cls.params["num_input_channels"],
        )
        cls.correct_regression_output = torch.rand(batch_size, cls.params["num_targets"])

        cls.correct_pretrain_output = torch.rand(
            batch_size,
            cls.params["num_input_channels"],
            cls.num_patches,
            cls.params["patch_length"],
        )

        cls.correct_forecast_output = torch.rand(
            batch_size,
            cls.params["prediction_length"],
            cls.params["num_input_channels"],
        )

        # Forecast restricted to 2 selected channels
        cls.correct_sel_forecast_output = torch.rand(batch_size, cls.params["prediction_length"], 2)

        cls.correct_classification_output = torch.rand(
            batch_size,
            cls.params["num_targets"],
        )

        cls.correct_classification_classes = torch.randint(0, cls.params["num_targets"], (batch_size,))
632

633
    def test_patchtsmixer_encoder(self):
        """The encoder maps patched input to (batch, channels, num_patches, d_model)."""
        fixtures = type(self)
        encoder = PatchTSMixerEncoder(PatchTSMixerConfig(**fixtures.params))
        result = encoder(fixtures.enc_data)
        expected_shape = fixtures.enc_output.shape
        self.assertEqual(result.last_hidden_state.shape, expected_shape)
638

639
    def test_patchmodel(self):
        """The backbone returns encoder states and the patched view of its raw input."""
        fixtures = type(self)
        backbone = PatchTSMixerModel(PatchTSMixerConfig(**fixtures.params))
        result = backbone(fixtures.data)
        self.assertEqual(result.last_hidden_state.shape, fixtures.enc_output.shape)
        self.assertEqual(result.patch_input.shape, fixtures.enc_data.shape)
645

646
    def test_pretrainhead(self):
        """The pretraining head projects encoder states back to patch space."""
        fixtures = type(self)
        pretrain_head = PatchTSMixerPretrainHead(config=PatchTSMixerConfig(**fixtures.params))
        reconstruction = pretrain_head(fixtures.enc_output)
        self.assertEqual(reconstruction.shape, fixtures.correct_pretrain_output.shape)
654

655
    def test_pretrain_full(self):
        """End-to-end pretraining model: output shapes match the fixtures and the loss is finite."""
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForPretraining(config)
        output = mdl(self.__class__.data)
        self.assertEqual(
            output.prediction_outputs.shape,
            self.__class__.correct_pretrain_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        # A bare `< np.inf` check would accept -inf; require a finite (non-NaN, non-infinite) loss
        loss_value = output.loss.item()
        self.assertTrue(np.isfinite(loss_value), msg=f"non-finite loss: {loss_value}")
665

666
    def test_pretrain_full_with_return_dict(self):
        """Pretraining model with `return_dict=False`: check the tuple output.

        The tuple is indexed as (loss, prediction_outputs, last_hidden_state, ...).
        """
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForPretraining(config)
        output = mdl(self.__class__.data, return_dict=False)
        self.assertEqual(output[1].shape, self.__class__.correct_pretrain_output.shape)
        self.assertEqual(output[2].shape, self.__class__.enc_output.shape)
        # A bare `< np.inf` check would accept -inf; require a finite (non-NaN, non-infinite) loss
        loss_value = output[0].item()
        self.assertTrue(np.isfinite(loss_value), msg=f"non-finite loss: {loss_value}")
673

674
    def test_forecast_head(self):
        """The forecast head turns encoder states into (batch, prediction_length, channels)."""
        fixtures = type(self)
        forecast_head = PatchTSMixerForPredictionHead(config=PatchTSMixerConfig(**fixtures.params))
        forecast = forecast_head(fixtures.enc_output)
        self.assertEqual(forecast.shape, fixtures.correct_forecast_output.shape)
683

684
    def check_module(
        self,
        task,
        params=None,
        output_hidden_states=True,
    ):
        """Build the model for `task`, run one forward pass and check shapes and loss.

        Args:
            task: One of "forecast", "classification", "regression" or "pretrain".
            params: Keyword arguments for `PatchTSMixerConfig`. Defaults to the class-level
                `params` created in `setUpClass`.
            output_hidden_states: Forwarded to the model. When truthy, the number of returned
                hidden states must equal `params["num_layers"]`; otherwise it must be None.

        Raises:
            ValueError: If `task` is not one of the supported names.
        """
        fixtures = self.__class__
        if params is None:
            params = fixtures.params
        config = PatchTSMixerConfig(**params)
        ref_samples = None  # expected `generate(...).sequences` shape; only for forecast/regression
        if task == "forecast":
            mdl = PatchTSMixerForPrediction(config)
            target_input = fixtures.correct_forecast_output
            if config.prediction_channel_indices is not None:
                target_output = fixtures.correct_sel_forecast_output
            else:
                target_output = target_input
            ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1, -1)
            ground_truth_arg = "future_values"
            output_predictions_arg = "prediction_outputs"
        elif task == "classification":
            mdl = PatchTSMixerForTimeSeriesClassification(config)
            target_input = fixtures.correct_classification_classes
            target_output = fixtures.correct_classification_output
            ground_truth_arg = "target_values"
            output_predictions_arg = "prediction_outputs"
        elif task == "regression":
            mdl = PatchTSMixerForRegression(config)
            target_input = fixtures.correct_regression_output
            target_output = fixtures.correct_regression_output
            ref_samples = target_output.unsqueeze(1).expand(-1, config.num_parallel_samples, -1)
            ground_truth_arg = "target_values"
            output_predictions_arg = "regression_outputs"
        elif task == "pretrain":
            mdl = PatchTSMixerForPretraining(config)
            target_input = None
            target_output = fixtures.correct_pretrain_output
            ground_truth_arg = None
            output_predictions_arg = "prediction_outputs"
        else:
            # Fail loudly instead of continuing with undefined locals
            raise ValueError(f"invalid task: {task!r}")

        enc_output = fixtures.enc_output

        model_kwargs = {"output_hidden_states": output_hidden_states}
        if target_input is not None:
            model_kwargs[ground_truth_arg] = target_input
        output = mdl(fixtures.data, **model_kwargs)

        prediction_outputs = getattr(output, output_predictions_arg)
        # Predictions may come back as a tuple of tensors; each must have the target shape
        if isinstance(prediction_outputs, tuple):
            for t in prediction_outputs:
                self.assertEqual(t.shape, target_output.shape)
        else:
            self.assertEqual(prediction_outputs.shape, target_output.shape)

        self.assertEqual(output.last_hidden_state.shape, enc_output.shape)

        if output_hidden_states:
            self.assertEqual(len(output.hidden_states), params["num_layers"])
        else:
            self.assertIsNone(output.hidden_states)

        # A bare `< np.inf` check would accept -inf; require a finite (non-NaN, non-infinite) loss
        loss_value = output.loss.item()
        self.assertTrue(np.isfinite(loss_value), msg=f"non-finite loss: {loss_value}")

        if config.loss == "nll" and task in ("forecast", "regression"):
            samples = mdl.generate(fixtures.data)
            self.assertEqual(samples.sequences.shape, ref_samples.shape)
756

757
    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],  # mode
                [True, False],  # self_attn
                [True, False, "mean", "std"],  # scaling
                [True, False],  # gated_attn
                [None, [0, 2]],  # prediction_channel_indices
                ["mse", "nll"],  # loss
            )
        )
    )
    def test_forecast(self, mode, self_attn, scaling, gated_attn, prediction_channel_indices, loss):
        """Forecast check via `check_module` for one point of the parameter grid above."""
        forecast_params = dict(
            type(self).params,
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            prediction_channel_indices=prediction_channel_indices,
            gated_attn=gated_attn,
            loss=loss,
        )
        self.check_module(task="forecast", params=forecast_params)
781

782
    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],  # mode
                [True, False],  # self_attn
                [True, False, "mean", "std"],  # scaling
                [True, False],  # gated_attn
                ["max_pool", "avg_pool"],  # head_aggregation
            )
        )
    )
    def test_classification(self, mode, self_attn, scaling, gated_attn, head_aggregation):
        """Classification check via `check_module` for one point of the parameter grid above."""
        classification_params = dict(
            type(self).params,
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            head_aggregation=head_aggregation,
            gated_attn=gated_attn,
        )
        self.check_module(task="classification", params=classification_params)
804

805
    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],  # mode
                [True, False],  # self_attn
                [True, False, "mean", "std"],  # scaling
                [True, False],  # gated_attn
                ["max_pool", "avg_pool"],  # head_aggregation
                ["mse", "nll"],  # loss
            )
        )
    )
    def test_regression(self, mode, self_attn, scaling, gated_attn, head_aggregation, loss):
        """Regression check via `check_module` for one point of the parameter grid above."""
        regression_params = dict(
            type(self).params,
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            head_aggregation=head_aggregation,
            gated_attn=gated_attn,
            loss=loss,
        )
        self.check_module(task="regression", params=regression_params)
829

830
    @parameterized.expand(
        list(
            itertools.product(
                ["common_channel", "mix_channel"],  # mode
                [True, False],  # self_attn
                [True, False, "mean", "std"],  # scaling
                [True, False],  # gated_attn
                ["random", "forecast"],  # mask_type
                [True, False],  # masked_loss
                [True, False],  # channel_consistent_masking
            )
        )
    )
    def test_pretrain(
        self,
        mode,
        self_attn,
        scaling,
        gated_attn,
        mask_type,
        masked_loss,
        channel_consistent_masking,
    ):
        """Pretraining check via `check_module` for one point of the masking/mixing grid above."""
        pretrain_params = dict(
            type(self).params,
            mode=mode,
            self_attn=self_attn,
            scaling=scaling,
            gated_attn=gated_attn,
            mask_type=mask_type,
            masked_loss=masked_loss,
            channel_consistent_masking=channel_consistent_masking,
        )
        self.check_module(task="pretrain", params=pretrain_params)
865

866
    def forecast_full_module(self, params=None, output_hidden_states=False, return_dict=None):
        """Run `PatchTSMixerForPrediction` on the shared fixture data and check its outputs.

        Args:
            params (`dict`, *optional*):
                Keyword arguments for `PatchTSMixerConfig`. Falls back to the class-level fixture
                `params` when `None`.
            output_hidden_states (`bool`, *optional*, defaults to `False`):
                Forwarded to the model. When truthy, the number of returned hidden states must equal
                `params["num_layers"]`; otherwise `hidden_states` must be `None`.
            return_dict (`bool`, *optional*):
                Forwarded to the model. When `False`, the model must return a plain tuple, which is
                re-wrapped into `PatchTSMixerForPredictionOutput` before the field checks.
        """
        if params is None:
            params = self.__class__.params
        config = PatchTSMixerConfig(**params)
        mdl = PatchTSMixerForPrediction(config)

        # With channel selection the expected predictions only cover the selected channels.
        if config.prediction_channel_indices is not None:
            target_val = self.__class__.correct_sel_forecast_output
        else:
            target_val = self.__class__.correct_forecast_output
        enc_output = self.__class__.enc_output

        output = mdl(
            self.__class__.data,
            # The full (unselected) future values are passed even when channels are selected.
            future_values=self.__class__.correct_forecast_output,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )

        if return_dict is False:
            # Anything other than a tuple here means `return_dict=False` was ignored.
            self.assertIsInstance(output, tuple)
        if isinstance(output, tuple):
            # NOTE(review): fields are re-bound positionally; this presumably relies on no
            # intermediate field being omitted from the tuple — confirm for new call sites.
            output = PatchTSMixerForPredictionOutput(*output)

        prediction_outputs = output.prediction_outputs
        if isinstance(prediction_outputs, tuple):
            # Distributional (nll) runs yield one tensor per distribution parameter, each of
            # which is expected to match the target shape (same rule as `check_module`).
            for param in prediction_outputs:
                self.assertEqual(param.shape, target_val.shape)
        else:
            self.assertEqual(prediction_outputs.shape, target_val.shape)

        self.assertEqual(output.last_hidden_state.shape, enc_output.shape)

        if output_hidden_states:
            self.assertEqual(len(output.hidden_states), params["num_layers"])
        else:
            self.assertIsNone(output.hidden_states)

        # `<` is False for NaN as well, so this rejects both NaN and +inf losses.
        self.assertLess(output.loss.item(), np.inf)

        if config.loss == "nll":
            samples = mdl.generate(self.__class__.data)
            ref_samples = target_val.unsqueeze(1).expand(-1, params["num_parallel_samples"], -1, -1)
            self.assertEqual(samples.sequences.shape, ref_samples.shape)
904

905
    def test_forecast_full(self):
        """Generic forecast check on the unmodified fixture params, with hidden states requested."""
        # NOTE(review): passes the shared class-level dict (not a copy); assumes check_module
        # does not mutate it — confirm.
        default_params = type(self).params
        self.check_module(params=default_params, task="forecast", output_hidden_states=True)
908

909
    def test_forecast_full_2(self):
        """Full prediction pass in mix_channel mode, returning hidden states."""
        mix_params = dict(type(self).params, mode="mix_channel")
        self.forecast_full_module(mix_params, output_hidden_states=True)
915

916
    def test_forecast_full_2_with_return_dict(self):
        """Full prediction pass in mix_channel mode with `return_dict=False` (tuple outputs)."""
        mix_params = dict(type(self).params, mode="mix_channel")
        self.forecast_full_module(mix_params, return_dict=False, output_hidden_states=True)
922

923
    def test_forecast_full_3(self):
        """Full prediction pass in mix_channel mode with self-attention enabled, returning hidden states.

        Differs from `test_forecast_full_2` by turning on `self_attn`, so the mix_channel +
        self-attention combination is also exercised through `forecast_full_module`.
        """
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            self_attn=True,
        )
        self.forecast_full_module(params, output_hidden_states=True)
929

930
    def test_forecast_full_5(self):
        """Full prediction pass with self-attention and sin/cos positional encoding enabled."""
        attn_params = dict(
            type(self).params,
            self_attn=True,
            use_positional_encoding=True,
            positional_encoding="sincos",
        )
        self.forecast_full_module(attn_params, output_hidden_states=True)
938

939
    def test_forecast_full_4(self):
        """Full prediction pass in mix_channel mode predicting only channels 0 and 2."""
        selected_params = dict(
            type(self).params,
            mode="mix_channel",
            prediction_channel_indices=[0, 2],
        )
        self.forecast_full_module(selected_params)
946

947
    def test_forecast_full_distributional(self):
        """NLL forecast with a normal distribution head, predicting only channels 0 and 2."""
        nll_params = dict(
            type(self).params,
            mode="mix_channel",
            prediction_channel_indices=[0, 2],
            loss="nll",
            distribution_output="normal",
        )
        self.forecast_full_module(nll_params)
957

958
    def test_forecast_full_distributional_2(self):
        """NLL forecast with the config's default distribution head, predicting channels 0 and 2."""
        # `distribution_output` is deliberately left unset so the PatchTSMixerConfig default is used.
        nll_params = dict(
            type(self).params,
            mode="mix_channel",
            prediction_channel_indices=[0, 2],
            loss="nll",
        )
        self.forecast_full_module(nll_params)
967

968
    def test_forecast_full_distributional_3(self):
        """NLL forecast with a normal distribution head over all channels (no channel selection)."""
        nll_params = dict(
            type(self).params,
            mode="mix_channel",
            loss="nll",
            distribution_output="normal",
        )
        self.forecast_full_module(nll_params)
977

978
    def test_forecast_full_distributional_4(self):
        """NLL forecast with a Student-t distribution head over all channels (no channel selection).

        Complements `test_forecast_full_distributional_3`, which uses a normal head with the same
        otherwise-identical setup.
        """
        params = self.__class__.params.copy()
        params.update(
            mode="mix_channel",
            loss="nll",
            distribution_output="student_t",
        )
        self.forecast_full_module(params)
987

988
    def test_classification_head(self):
        """The linear head maps the fixture encoder output to the classification output shape."""
        fixtures = type(self)
        head = PatchTSMixerLinearHead(config=PatchTSMixerConfig(**fixtures.params))
        logits = head(fixtures.enc_output)
        self.assertEqual(logits.shape, fixtures.correct_classification_output.shape)
997

998
    def test_classification_full(self):
        """End-to-end classification forward pass: output shapes and a finite loss."""
        fixtures = type(self)
        model = PatchTSMixerForTimeSeriesClassification(PatchTSMixerConfig(**fixtures.params))
        result = model(fixtures.data, target_values=fixtures.correct_classification_classes)

        self.assertEqual(result.prediction_outputs.shape, fixtures.correct_classification_output.shape)
        self.assertEqual(result.last_hidden_state.shape, fixtures.enc_output.shape)
        # Fails for NaN as well as +inf.
        self.assertLess(result.loss.item(), np.inf)
1011

1012
    def test_classification_full_with_return_dict(self):
        """Classification forward pass with `return_dict=False`.

        The model must return a plain tuple. The tuple is re-wrapped into
        `PatchTSMixerForTimeSeriesClassificationOutput` so its fields can be checked by name.
        """
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForTimeSeriesClassification(config)
        output = mdl(
            self.__class__.data,
            target_values=self.__class__.correct_classification_classes,
            return_dict=False,
        )
        # Anything other than a tuple here means `return_dict=False` was ignored.
        self.assertIsInstance(output, tuple)
        # NOTE(review): fields are re-bound positionally; assumes the tuple order matches the
        # output dataclass field order — confirm if the model's return signature changes.
        output = PatchTSMixerForTimeSeriesClassificationOutput(*output)

        self.assertEqual(
            output.prediction_outputs.shape,
            self.__class__.correct_classification_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        # `<` is False for NaN as well, so this rejects both NaN and +inf losses.
        self.assertLess(output.loss.item(), np.inf)
1028

1029
    def test_regression_head(self):
        """The linear head maps the fixture encoder output to the regression output shape."""
        fixtures = type(self)
        regression_head = PatchTSMixerLinearHead(config=PatchTSMixerConfig(**fixtures.params))
        self.assertEqual(
            regression_head(fixtures.enc_output).shape,
            fixtures.correct_regression_output.shape,
        )
1036

1037
    def test_regression_full(self):
        """End-to-end regression forward pass: output shapes and a finite loss."""
        fixtures = type(self)
        model = PatchTSMixerForRegression(PatchTSMixerConfig(**fixtures.params))
        result = model(fixtures.data, target_values=fixtures.correct_regression_output)

        self.assertEqual(result.regression_outputs.shape, fixtures.correct_regression_output.shape)
        self.assertEqual(result.last_hidden_state.shape, fixtures.enc_output.shape)
        # Fails for NaN as well as +inf.
        self.assertLess(result.loss.item(), np.inf)
1047

1048
    def test_regression_full_with_return_dict(self):
        """Regression forward pass with `return_dict=False`.

        The model must return a plain tuple. The tuple is re-wrapped into
        `PatchTSMixerForRegressionOutput` so its fields can be checked by name.
        """
        config = PatchTSMixerConfig(**self.__class__.params)
        mdl = PatchTSMixerForRegression(config)
        output = mdl(
            self.__class__.data,
            target_values=self.__class__.correct_regression_output,
            return_dict=False,
        )
        # Anything other than a tuple here means `return_dict=False` was ignored.
        self.assertIsInstance(output, tuple)
        # NOTE(review): fields are re-bound positionally; assumes the tuple order matches the
        # output dataclass field order — confirm if the model's return signature changes.
        output = PatchTSMixerForRegressionOutput(*output)

        self.assertEqual(
            output.regression_outputs.shape,
            self.__class__.correct_regression_output.shape,
        )
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        # `<` is False for NaN as well, so this rejects both NaN and +inf losses.
        self.assertLess(output.loss.item(), np.inf)
1064

1065
    def _check_regression_distribution(self, distribution_output):
        """Check an NLL-trained `PatchTSMixerForRegression` with the given distribution head.

        Verifies the forward pass (shape of every distribution parameter, encoder output shape,
        finite loss) and that `generate` returns samples shaped
        `(batch, num_parallel_samples, num_targets)`.

        Args:
            distribution_output (`str`): value for `PatchTSMixerConfig.distribution_output`.
        """
        params = self.__class__.params.copy()
        params.update(loss="nll", distribution_output=distribution_output)
        config = PatchTSMixerConfig(**params)
        mdl = PatchTSMixerForRegression(config)
        target = self.__class__.correct_regression_output

        output = mdl(self.__class__.data, target_values=target)
        # An NLL head yields one tensor per distribution parameter. Every parameter (e.g. also
        # Student-t's third one) must match the regression target shape.
        self.assertIsInstance(output.regression_outputs, tuple)
        for param in output.regression_outputs:
            self.assertEqual(param.shape, target.shape)
        self.assertEqual(output.last_hidden_state.shape, self.__class__.enc_output.shape)
        # `<` is False for NaN as well, so this rejects both NaN and +inf losses.
        self.assertLess(output.loss.item(), np.inf)

        samples = mdl.generate(self.__class__.data)
        ref_samples = target.unsqueeze(1).expand(-1, params["num_parallel_samples"], -1)
        self.assertEqual(samples.sequences.shape, ref_samples.shape)

    def test_regression_full_distribute(self):
        """Regression with an NLL loss and a normal distribution head."""
        self._check_regression_distribution("normal")

    def test_regression_full_distribute_2(self):
        """Regression with an NLL loss and a Student-t distribution head."""
        self._check_regression_distribution("student_t")
1116

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.