pytorch

Форк
0
/
test_scheduler.py 
179 строк · 6.8 Кб
1
# Owner(s): ["module: unknown"]
2

3
from torch import nn
4

5
from torch.ao.pruning import WeightNormSparsifier
6
from torch.ao.pruning import BaseScheduler, LambdaSL, CubicSL
7

8
from torch.testing._internal.common_utils import TestCase
9

10
import warnings
11

12
class ImplementedScheduler(BaseScheduler):
    """Minimal concrete scheduler used by the tests below: it halves every
    group's sparsity level once training has progressed past epoch 0."""

    def get_sl(self):
        """Return the next sparsity level for each sparsifier group."""
        # Before any epoch has elapsed, report the stored baseline levels.
        if self.last_epoch <= 0:
            return list(self.base_sl)
        return [0.5 * group['sparsity_level']
                for group in self.sparsifier.groups]
19

20

21
class TestScheduler(TestCase):
    def test_constructor(self):
        """The scheduler attaches to its sparsifier and snapshots the
        initial sparsity level of every group as ``base_sl``."""
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        assert scheduler.sparsifier is sparsifier
        # Construction counts as the first step (same convention as the
        # LR schedulers).
        assert scheduler._step_count == 1
        assert scheduler.base_sl == [sparsifier.groups[0]['sparsity_level']]

    def test_order_of_steps(self):
        """Checks if the warning is thrown if the scheduler step is called
        before the sparsifier step"""

        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        scheduler = ImplementedScheduler(sparsifier)

        # Sparsifier step is not called
        with self.assertWarns(UserWarning):
            scheduler.step()

        # Correct order has no warnings
        # Note: This will trigger if other warnings are present.
        with warnings.catch_warnings(record=True) as w:
            sparsifier.step()
            scheduler.step()
            # Make sure there is no warning related to the base_scheduler
            for warning in w:
                fname = warning.filename
                fname = '/'.join(fname.split('/')[-5:])
                # Fixed path: the scheduler lives under torch/ao/pruning (see
                # the imports above); the previous 'torch/ao/sparsity/...'
                # path could never match, which made this guard vacuous.
                assert fname != 'torch/ao/pruning/scheduler/base_scheduler.py'

    def test_step(self):
        """``ImplementedScheduler`` halves the sparsity level on each step
        taken after epoch 0."""
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        assert sparsifier.groups[0]['sparsity_level'] == 0.5
        scheduler = ImplementedScheduler(sparsifier)
        # Attaching the scheduler must not mutate the current level.
        assert sparsifier.groups[0]['sparsity_level'] == 0.5

        sparsifier.step()
        scheduler.step()
        assert sparsifier.groups[0]['sparsity_level'] == 0.25

    def test_lambda_scheduler(self):
        """``LambdaSL`` multiplies the base level by the value of the
        user-supplied lambda evaluated at the current epoch."""
        model = nn.Sequential(
            nn.Linear(16, 16)
        )
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=None)
        assert sparsifier.groups[0]['sparsity_level'] == 0.5
        scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
        assert sparsifier.groups[0]['sparsity_level'] == 0.0  # Epoch 0
        scheduler.step()
        assert sparsifier.groups[0]['sparsity_level'] == 5.0  # Epoch 1
85

86

87
class TestCubicScheduler(TestCase):
    """Tests for ``CubicSL``: sparsity stays 0 (or s_0) until t_0, then grows
    cubically from s_0 to the per-group target over delta_t * n steps."""

    def setUp(self):
        # Target (final) sparsity levels for the two pruned Linear layers.
        self.model_sparse_config = [
            {'tensor_fqn': '0.weight', 'sparsity_level': 0.8},
            {'tensor_fqn': '2.weight', 'sparsity_level': 0.4},
        ]
        self.sorted_sparse_levels = [conf['sparsity_level'] for conf in self.model_sparse_config]
        # s_0 and t_0 of the cubic schedule.
        self.initial_sparsity = 0.1
        self.initial_step = 3

    def _make_model(self, **kwargs):
        """Builds a small model with two prunable Linear layers."""
        model = nn.Sequential(
            nn.Linear(13, 17),
            nn.Dropout(0.5),
            nn.Linear(17, 3),
        )
        return model

    def _make_scheduler(self, model, **kwargs):
        """Prepares a sparsifier on ``model`` and attaches a ``CubicSL``.

        Extra ``kwargs`` override or extend the CubicSL constructor
        arguments. Returns the ``(sparsifier, scheduler)`` pair.
        """
        sparsifier = WeightNormSparsifier()
        sparsifier.prepare(model, config=self.model_sparse_config)

        scheduler_args = {
            'init_sl': self.initial_sparsity,
            'init_t': self.initial_step,
        }
        scheduler_args.update(kwargs)

        scheduler = CubicSL(sparsifier, **scheduler_args)
        return sparsifier, scheduler

    @staticmethod
    def _get_sparsity_levels(sparsifier, precision=32):
        r"""Gets the current levels of sparsity in a sparsifier."""
        return [round(group['sparsity_level'], precision) for group in sparsifier.groups]

    def test_constructor(self):
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
        self.assertIs(
            scheduler.sparsifier, sparsifier,
            msg="Sparsifier is not properly attached")
        self.assertEqual(
            scheduler._step_count, 1,
            msg="Scheduler is initialized with incorrect step count")
        self.assertEqual(
            scheduler.base_sl, self.sorted_sparse_levels,
            msg="Scheduler did not store the target sparsity levels correctly")

        # Value before t_0 is 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler")

        # Value before t_0 is s_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
        self.assertEqual(
            self._get_sparsity_levels(sparsifier),
            scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset correctly after attaching to the Scheduler")

    def test_step(self):
        # For n=5, dt=2, there will be totally 10 steps between s_0 and s_f, starting from t_0
        model = self._make_model()
        sparsifier, scheduler = self._make_scheduler(
            model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5)

        scheduler.step()
        scheduler.step()
        self.assertEqual(scheduler._step_count, 3, msg="Scheduler step_count is expected to increment")
        # Value before t_0 is supposed to be 0
        self.assertEqual(
            self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
            msg="Scheduler step updating the sparsity level before t_0")

        scheduler.step()  # Step = 3  =>  sparsity = initial_sparsity
        self.assertEqual(
            self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(self.initial_sparsity),
            msg="Sparsifier is not reset to initial sparsity at the first step")

        scheduler.step()  # Step = 4  =>  sparsity ~ [0.3, 0.2]
        self.assertEqual(
            self._get_sparsity_levels(sparsifier, 1), [0.3, 0.2],
            msg="Sparsity level is not set correctly after the first step")

        current_step = scheduler._step_count - scheduler.init_t[0] - 1
        more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
        for _ in range(more_steps_needed):  # More steps needed to final sparsity level
            scheduler.step()
        # Fixed message typo ("afer" -> "after") and stray trailing space.
        self.assertEqual(
            self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels,
            msg="Sparsity level is not reaching the target level after delta_t * n steps")
180

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.