pytorch

Форк
0
/
sparse_lp_regularizer_test.py 
71 строка · 2.5 Кб
1

2

3

4

5

6
import hypothesis
7
from hypothesis import given, settings, HealthCheck
8
import hypothesis.strategies as st
9
import numpy as np
10

11
from caffe2.python import core
12
import caffe2.python.hypothesis_test_util as hu
13

14

15
class TestSparseLpNorm(hu.HypothesisTestCase):
16

17
    @staticmethod
18
    def ref_lpnorm(param_in, p, reg_lambda):
19
        """Reference function that should be matched by the Caffe2 operator."""
20
        if p == 2.0:
21
            return param_in * (1 - reg_lambda)
22
        if p == 1.0:
23
            reg_term = np.ones_like(param_in) * reg_lambda * np.sign(param_in)
24
            param_out = param_in - reg_term
25
            param_out[np.abs(param_in) <= reg_lambda] = 0.
26
            return param_out
27
        raise ValueError
28

29
    # Suppress filter_too_much health check.
30
    # Likely caused by `assume` call falling through too often.
31
    @settings(suppress_health_check=[HealthCheck.filter_too_much])
32
    @given(inputs=hu.tensors(n=1, min_dim=2, max_dim=2),
33
           p=st.integers(min_value=1, max_value=2),
34
           reg_lambda=st.floats(min_value=1e-4, max_value=1e-1),
35
           data_strategy=st.data(),
36
           **hu.gcs_cpu_only)
37
    def test_sparse_lpnorm(self, inputs, p, reg_lambda, data_strategy, gc, dc):
38

39
        param, = inputs
40
        param += 0.02 * np.sign(param)
41
        param[param == 0.0] += 0.02
42

43
        # Create an indexing array containing values that are lists of indices,
44
        # which index into param
45
        indices = data_strategy.draw(
46
            hu.tensor(dtype=np.int64, min_dim=1, max_dim=1,
47
                      elements=st.sampled_from(np.arange(param.shape[0]))),
48
        )
49
        hypothesis.note('indices.shape: %s' % str(indices.shape))
50

51
        # For now, the indices must be unique
52
        hypothesis.assume(np.array_equal(np.unique(indices.flatten()),
53
                                         np.sort(indices.flatten())))
54

55
        op = core.CreateOperator(
56
            "SparseLpRegularizer",
57
            ["param", "indices"],
58
            ["param"],
59
            p=float(p),
60
            reg_lambda=reg_lambda,
61
        )
62

63
        def ref_sparse_lp_regularizer(param, indices, grad=None):
64
            param_out = np.copy(param)
65
            for _, index in enumerate(indices):
66
                param_out[index] = self.ref_lpnorm(
67
                    param[index],
68
                    p=p,
69
                    reg_lambda=reg_lambda,
70
                )
71
            return (param_out,)
72

73
        self.assertReferenceChecks(
74
            gc, op, [param, indices],
75
            ref_sparse_lp_regularizer
76
        )
77

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.