1
# Owner(s): ["module: unknown"]
5
from torch.ao.pruning import WeightNormSparsifier
6
from torch.ao.pruning import BaseScheduler, LambdaSL, CubicSL
8
from torch.testing._internal.common_utils import TestCase
12
# Minimal concrete BaseScheduler used by TestScheduler in this file.
# Rule visible below: when last_epoch > 0, each group's current
# 'sparsity_level' is halved; otherwise a copy of self.base_sl is returned.
# NOTE(review): this copy of the file looks corrupted. The bare integers
# interleaved with the code appear to be line-number residue from an
# extraction, all indentation is gone, and the method header that should
# enclose the `if`/`return` lines below (they use `self`) is missing.
# Restore this block from version control before relying on it — TODO.
class ImplementedScheduler(BaseScheduler):
14
# After the first epoch: halve every group's current sparsity level.
if self.last_epoch > 0:
15
return [group['sparsity_level'] * 0.5
16
for group in self.sparsifier.groups]
18
# Otherwise (last_epoch <= 0): return a shallow copy of the stored base
# levels so callers cannot mutate self.base_sl through the returned list.
return list(self.base_sl)
21
# Tests for BaseScheduler subclasses (ImplementedScheduler and LambdaSL)
# driving a WeightNormSparsifier prepared with config=None.
# NOTE(review): this copy looks corrupted. The bare integers between
# statements appear to be line-number residue, indentation is lost, and
# several lines are missing: every `nn.Sequential(` call is left unclosed,
# and the bodies of the `with` blocks below are absent. `nn` and `warnings`
# are used here, but no import for them is visible; presumably they were on
# the missing header lines. Restore from version control — TODO confirm.
class TestScheduler(TestCase):
22
# Attaching the scheduler should keep a reference to the sparsifier, start
# with _step_count == 1, and record base_sl from group 0's sparsity level.
def test_constructor(self):
23
model = nn.Sequential(
26
sparsifier = WeightNormSparsifier()
27
sparsifier.prepare(model, config=None)
28
scheduler = ImplementedScheduler(sparsifier)
30
assert scheduler.sparsifier is sparsifier
31
assert scheduler._step_count == 1
32
assert scheduler.base_sl == [sparsifier.groups[0]['sparsity_level']]
34
def test_order_of_steps(self):
35
"""Checks if the warning is thrown if the scheduler step is called
36
before the sparsifier step"""
38
model = nn.Sequential(
41
sparsifier = WeightNormSparsifier()
42
sparsifier.prepare(model, config=None)
43
scheduler = ImplementedScheduler(sparsifier)
45
# Sparsifier step is not called
46
# NOTE(review): the statements that belong inside this `with` are missing
# from this copy (presumably the out-of-order step calls) — TODO confirm.
with self.assertWarns(UserWarning):
49
# Correct order has no warnings
50
# Note: This will trigger if other warnings are present.
51
with warnings.catch_warnings(record=True) as w:
54
# Make sure there is no warning related to the base_scheduler
56
# NOTE(review): `warning` is not bound by any visible line; a loop over
# the recorded warnings `w` appears to be missing — TODO confirm.
fname = warning.filename
57
# Keep only the last five path components, so the comparison does not
# depend on where torch is installed.
fname = '/'.join(fname.split('/')[-5:])
58
# NOTE(review): the schedulers are imported from torch.ao.pruning, but this
# compares against a torch/ao/sparsity/... path. If base_scheduler.py now
# lives under torch/ao/pruning/, this assert can never fail and the check
# is vacuous. Confirm the current module path.
assert fname != 'torch/ao/sparsity/scheduler/base_scheduler.py'
61
# NOTE(review): the lines below appear to belong to a separate test whose
# `def` line is missing (the residue numbers jump from 58 to 61).
# The level is 0.5 both before and right after attaching the scheduler,
# and 0.25 (= 0.5 * 0.5, ImplementedScheduler's halving rule) afterwards.
# The step call(s) between those asserts are missing here — TODO confirm.
model = nn.Sequential(
64
sparsifier = WeightNormSparsifier()
65
sparsifier.prepare(model, config=None)
66
assert sparsifier.groups[0]['sparsity_level'] == 0.5
67
scheduler = ImplementedScheduler(sparsifier)
68
assert sparsifier.groups[0]['sparsity_level'] == 0.5
72
assert sparsifier.groups[0]['sparsity_level'] == 0.25
74
# LambdaSL with `lambda epoch: epoch * 10`. The expected levels below equal
# 0.5 * lambda(epoch): 0.0 at epoch 0 and 5.0 at epoch 1.
# NOTE(review): the step call(s) between the last two asserts are missing
# from this copy — TODO confirm.
def test_lambda_scheduler(self):
75
model = nn.Sequential(
78
sparsifier = WeightNormSparsifier()
79
sparsifier.prepare(model, config=None)
80
assert sparsifier.groups[0]['sparsity_level'] == 0.5
81
scheduler = LambdaSL(sparsifier, lambda epoch: epoch * 10)
82
assert sparsifier.groups[0]['sparsity_level'] == 0.0 # Epoch 0
84
assert sparsifier.groups[0]['sparsity_level'] == 5.0 # Epoch 1
87
# Tests for CubicSL driving a WeightNormSparsifier prepared with two
# per-tensor configs ('0.weight' -> 0.8, '2.weight' -> 0.4).
# NOTE(review): this copy looks corrupted. The bare integers are
# line-number residue and indentation is lost. Lines are also missing:
#   - the `def` of the fixture that assigns the self.* attributes below
#     (presumably setUp);
#   - the assignment of self.initial_step;
#   - the `scheduler_args = {` opening in _make_scheduler;
#   - the decorator that lets _get_sparsity_levels be called without
#     `self` (presumably @staticmethod);
#   - the assertion-call openings (presumably `self.assertEqual(`);
#   - the `def` of the stepping test near residue line 150.
# Restore from version control — TODO confirm.
class TestCubicScheduler(TestCase):
89
self.model_sparse_config = [
90
{'tensor_fqn': '0.weight', 'sparsity_level': 0.8},
91
{'tensor_fqn': '2.weight', 'sparsity_level': 0.4},
93
# Target levels in config order (despite the name, nothing is sorted here).
self.sorted_sparse_levels = [conf['sparsity_level'] for conf in self.model_sparse_config]
94
# Starting sparsity, passed to CubicSL as init_sl by _make_scheduler.
self.initial_sparsity = 0.1
97
# Builds the model under test. Its layers are missing from this copy
# (`nn.Sequential(` is left unclosed).
def _make_model(self, **kwargs):
98
model = nn.Sequential(
105
# Prepares a fresh WeightNormSparsifier on `model` with
# self.model_sparse_config and wraps it in CubicSL. Default init_sl/init_t
# come from the fixture; any **kwargs override them.
# Returns (sparsifier, scheduler).
def _make_scheduler(self, model, **kwargs):
106
sparsifier = WeightNormSparsifier()
107
sparsifier.prepare(model, config=self.model_sparse_config)
110
'init_sl': self.initial_sparsity,
111
# NOTE(review): self.initial_step is not assigned by any visible line.
'init_t': self.initial_step,
113
scheduler_args.update(kwargs)
115
scheduler = CubicSL(sparsifier, **scheduler_args)
116
return sparsifier, scheduler
119
# Rounds each group's sparsity level to `precision` decimal places.
# The stepping test passes precision=1 to compare against approximate values.
def _get_sparsity_levels(sparsifier, precision=32):
120
r"""Gets the current levels of sparsity in a sparsifier."""
121
return [round(group['sparsity_level'], precision) for group in sparsifier.groups]
123
# Checks what CubicSL does on attach:
#   - it keeps the sparsifier;
#   - it starts at _step_count == 1;
#   - it stores the configured target levels as base_sl;
#   - it resets the levels to 0.0 (initially_zero=True) or to init_sl
#     (initially_zero=False).
def test_constructor(self):
124
model = self._make_model()
125
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=True)
127
scheduler.sparsifier, sparsifier,
128
msg="Sparsifier is not properly attached")
130
scheduler._step_count, 1,
131
msg="Scheduler is initialized with incorrect step count")
133
scheduler.base_sl, self.sorted_sparse_levels,
134
msg="Scheduler did not store the target sparsity levels correctly")
136
# Value before t_0 is 0
138
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
139
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
141
# Value before t_0 is s_0
142
model = self._make_model()
143
sparsifier, scheduler = self._make_scheduler(model=model, initially_zero=False)
145
self._get_sparsity_levels(sparsifier),
146
scheduler._make_sure_a_list(self.initial_sparsity),
147
msg="Sparsifier is not reset correctly after attaching to the Scheduler")
150
# NOTE(review): the lines below appear to belong to a separate stepping
# test whose `def` line is missing.
# With total_t (n) = 5 and delta_t (dt) = 2, there are 10 steps in total
# between s_0 and s_f, starting from t_0.
model = self._make_model()
152
sparsifier, scheduler = self._make_scheduler(
153
model=model, initially_zero=True, init_t=3, delta_t=2, total_t=5)
157
# NOTE(review): the calls that advance _step_count to 3 are missing from
# this copy — TODO confirm.
self.assertEqual(scheduler._step_count, 3, msg="Scheduler step_count is expected to increment")
158
# Value before t_0 is supposed to be 0
160
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(0.0),
161
msg="Scheduler step updating the sparsity level before t_0")
163
scheduler.step() # Step = 3 => sparsity = initial_sparsity
165
self._get_sparsity_levels(sparsifier), scheduler._make_sure_a_list(self.initial_sparsity),
166
msg="Sparsifier is not reset to initial sparsity at the first step")
168
scheduler.step() # Step = 4 => sparsity ~ [0.3, 0.2]
170
self._get_sparsity_levels(sparsifier, 1), [0.3, 0.2],
171
msg="Sparsity level is not set correctly after the first step")
173
# Count the steps already taken past init_t, then run the remaining steps
# up to delta_t * total_t so the levels should reach the configured targets.
current_step = scheduler._step_count - scheduler.init_t[0] - 1
174
more_steps_needed = scheduler.delta_t[0] * scheduler.total_t[0] - current_step
175
# NOTE(review): the loop body (presumably the step call) and the opening of
# the final assertion are missing from this copy. The assertion message
# also contains a typo ("afer" for "after").
for _ in range(more_steps_needed): # More steps needed to final sparsity level
178
self._get_sparsity_levels(sparsifier), self.sorted_sparse_levels,
179
msg="Sparsity level is not reaching the target level afer delta_t * n steps ")