4
from mergekit.sparsify import SparsificationMethod, sparsify
9
res = torch.randn(128, 64)
10
res[res == 0] = 7 # very low chance, but hey!
15
def test_full_density(self, sample_tensor):
17
sparsify(sample_tensor, density=1, method=SparsificationMethod.magnitude),
21
def test_zero_density(self, sample_tensor):
    """Check that magnitude sparsification rejects a density of zero.

    Calls ``sparsify`` on ``sample_tensor`` with ``density=0`` and the
    ``magnitude`` method, and expects the call to raise ``AssertionError``.
    """
    magnitude = SparsificationMethod.magnitude
    # The test passes only if the call below raises AssertionError.
    with pytest.raises(AssertionError):
        sparsify(sample_tensor, density=0, method=magnitude)
25
def test_partial_density(self, sample_tensor):
27
sample_tensor, density=0.5, method=SparsificationMethod.magnitude
29
assert torch.count_nonzero(result) == sample_tensor.view(-1).shape[0] // 2
35
def test_bernoulli_with_rescale(self, sample_tensor):
36
ref_abs_sum = sample_tensor.abs().sum()
37
avg_abs_sum = torch.zeros_like(ref_abs_sum)
38
for _ in range(TestBernoulli.NUM_ITERATIONS):
40
sample_tensor, density=0.5, method=SparsificationMethod.rescaled_random
42
avg_abs_sum += rescaled.abs().sum()
43
avg_abs_sum /= TestBernoulli.NUM_ITERATIONS
45
assert torch.isclose(avg_abs_sum, ref_abs_sum, rtol=0.01)
47
def test_bernoulli_without_rescale(self, sample_tensor):
49
sample_tensor, density=0.5, method=SparsificationMethod.random
51
assert 0 < torch.count_nonzero(result) <= sample_tensor.view(-1).shape[0]
53
def test_cpu_dtypes(self, sample_tensor):
54
for dt in (torch.float16, torch.bfloat16, torch.float32):
56
tensor=sample_tensor.to(dtype=dt).cpu(),
58
method=SparsificationMethod.rescaled_random,