pytorch
71 строка · 2.5 Кб
1
2
3
4
5
6import hypothesis7from hypothesis import given, settings, HealthCheck8import hypothesis.strategies as st9import numpy as np10
11from caffe2.python import core12import caffe2.python.hypothesis_test_util as hu13
14
class TestSparseLpNorm(hu.HypothesisTestCase):
    """Property tests for the Caffe2 ``SparseLpRegularizer`` operator.

    The operator output is checked against a NumPy reference that applies
    ``ref_lpnorm`` to the rows of ``param`` selected by ``indices`` and
    leaves every other row unchanged.
    """

    @staticmethod
    def ref_lpnorm(param_in, p, reg_lambda):
        """Reference function that should be matched by the Caffe2 operator.

        Args:
            param_in: NumPy array of parameter values (not modified).
            p: Norm order; only 1 and 2 are supported.
            reg_lambda: Regularization strength.

        Returns:
            A new array with the same shape as ``param_in``:

            * ``p == 2``: ``param_in * (1 - reg_lambda)`` (multiplicative
              shrinkage).
            * ``p == 1``: soft-thresholding. Each entry moves ``reg_lambda``
              toward zero, and entries with ``|x| <= reg_lambda`` become 0.

        Raises:
            ValueError: if ``p`` is neither 1 nor 2.
        """
        if p == 2.0:
            return param_in * (1 - reg_lambda)
        if p == 1.0:
            param_out = param_in - reg_lambda * np.sign(param_in)
            # Entries that would cross (or land on) zero are clamped to 0.
            param_out[np.abs(param_in) <= reg_lambda] = 0.
            return param_out
        raise ValueError(
            'ref_lpnorm only supports p=1 or p=2, got p=%r' % (p,))

    # Suppress filter_too_much health check.
    # Likely caused by `assume` call falling through too often.
    @settings(suppress_health_check=[HealthCheck.filter_too_much])
    @given(inputs=hu.tensors(n=1, min_dim=2, max_dim=2),
           p=st.integers(min_value=1, max_value=2),
           reg_lambda=st.floats(min_value=1e-4, max_value=1e-1),
           data_strategy=st.data(),
           **hu.gcs_cpu_only)
    def test_sparse_lpnorm(self, inputs, p, reg_lambda, data_strategy, gc, dc):
        """Check SparseLpRegularizer against ``ref_lpnorm`` on random rows.

        ``gc``/``dc`` are supplied via ``hu.gcs_cpu_only``; only ``gc`` is
        used here (presumably the device option to run on -- confirm in
        ``hypothesis_test_util``).
        """
        param, = inputs
        # Push every entry at least 0.02 away from zero (exact zeros become
        # 0.02). Presumably this keeps np.sign() from ever being 0.
        param += 0.02 * np.sign(param)
        param[param == 0.0] += 0.02

        # Draw a 1-D int64 array of row indices into param.
        indices = data_strategy.draw(
            hu.tensor(dtype=np.int64, min_dim=1, max_dim=1,
                      elements=st.sampled_from(np.arange(param.shape[0]))),
        )
        hypothesis.note('indices.shape: %s' % (indices.shape,))

        # For now, the indices must be unique. The reference below computes
        # every row from the original param, so a repeated index would be
        # regularized only once there. NOTE(review): presumably the operator
        # would apply it again for each repeat -- confirm.
        hypothesis.assume(np.unique(indices).size == indices.size)

        op = core.CreateOperator(
            "SparseLpRegularizer",
            ["param", "indices"],
            ["param"],
            p=float(p),
            reg_lambda=reg_lambda,
        )

        def ref_sparse_lp_regularizer(param, indices, grad=None):
            """Apply ``ref_lpnorm`` to the indexed rows of a copy of param.

            ``grad`` is accepted but unused.
            """
            param_out = np.copy(param)
            for index in indices:
                param_out[index] = self.ref_lpnorm(
                    param[index],
                    p=p,
                    reg_lambda=reg_lambda,
                )
            return (param_out,)

        self.assertReferenceChecks(
            gc, op, [param, indices],
            ref_sparse_lp_regularizer
        )