mergekit

Форк
0
/
test_sparsify.py 
59 строк · 1.9 Кб
1
import pytest
2
import torch
3

4
from mergekit.sparsify import SparsificationMethod, sparsify
5

6

7
@pytest.fixture
8
def sample_tensor():
9
    res = torch.randn(128, 64)
10
    res[res == 0] = 7  # very low chance, but hey!
11
    return res
12

13

14
class TestMagnitude:
15
    def test_full_density(self, sample_tensor):
16
        assert torch.equal(
17
            sparsify(sample_tensor, density=1, method=SparsificationMethod.magnitude),
18
            sample_tensor,
19
        )
20

21
    def test_zero_density(self, sample_tensor):
22
        with pytest.raises(AssertionError):
23
            sparsify(sample_tensor, density=0, method=SparsificationMethod.magnitude)
24

25
    def test_partial_density(self, sample_tensor):
26
        result = sparsify(
27
            sample_tensor, density=0.5, method=SparsificationMethod.magnitude
28
        )
29
        assert torch.count_nonzero(result) == sample_tensor.view(-1).shape[0] // 2
30

31

32
class TestBernoulli:
33
    NUM_ITERATIONS = 1000
34

35
    def test_bernoulli_with_rescale(self, sample_tensor):
36
        ref_abs_sum = sample_tensor.abs().sum()
37
        avg_abs_sum = torch.zeros_like(ref_abs_sum)
38
        for _ in range(TestBernoulli.NUM_ITERATIONS):
39
            rescaled = sparsify(
40
                sample_tensor, density=0.5, method=SparsificationMethod.rescaled_random
41
            )
42
            avg_abs_sum += rescaled.abs().sum()
43
        avg_abs_sum /= TestBernoulli.NUM_ITERATIONS
44

45
        assert torch.isclose(avg_abs_sum, ref_abs_sum, rtol=0.01)
46

47
    def test_bernoulli_without_rescale(self, sample_tensor):
48
        result = sparsify(
49
            sample_tensor, density=0.5, method=SparsificationMethod.random
50
        )
51
        assert 0 < torch.count_nonzero(result) <= sample_tensor.view(-1).shape[0]
52

53
    def test_cpu_dtypes(self, sample_tensor):
54
        for dt in (torch.float16, torch.bfloat16, torch.float32):
55
            sparsify(
56
                tensor=sample_tensor.to(dtype=dt).cpu(),
57
                density=0.5,
58
                method=SparsificationMethod.rescaled_random,
59
            )
60

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.