from caffe2.python import workspace, crf
from caffe2.python.cnn import CNNModelHelper
from caffe2.python.crf_predict import crf_update_predictions
from caffe2.python.test_util import TestCase
import hypothesis.strategies as st
from hypothesis import given, settings
import numpy as np


class TestCrfDecode(TestCase):
    """Property-based test of CRF prediction updating.

    Checks that ``crf_update_predictions`` produces the same output as the
    CRF layer's own ``update_predictions`` for randomly generated inputs.
    """

    @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15))
    @settings(deadline=2000)
    def test_crf_viterbi(self, num_tags, num_words):
        """Both prediction-update paths must agree to within 1e-4.

        Args:
            num_tags: number of tags (columns of ``predictions``), in [2, 4].
            num_words: number of words (rows of ``predictions``), in [2, 15].
        """
        model = CNNModelHelper(name='external')

        # Random per-word tag scores: one row per word, one column per tag.
        predictions = np.random.randn(num_words, num_tags).astype(np.float32)
        # Square transition matrix of side num_tags + 2; presumably the two
        # extra states are the CRF start/end tags -- TODO confirm in crf.py.
        # np.random.uniform returns float64, so cast to float32 to match the
        # dtype of `predictions`.
        transitions = np.random.uniform(
            low=-1, high=1, size=(num_tags + 2, num_tags + 2)
        ).astype(np.float32)

        predictions_blob, transitions_blob = (
            model.net.AddExternalInputs('predictions', 'crf_transitions')
        )
        workspace.FeedBlob(str(transitions_blob), transitions)
        workspace.FeedBlob(str(predictions_blob), predictions)
        crf_layer = crf.CRFWithLoss(model, num_tags, transitions_blob)

        # Path under test and the reference path. Both calls return blob
        # references whose values are fetched after the nets run below.
        updated_predictions = crf_update_predictions(
            model, crf_layer, predictions_blob
        )
        ref_predictions = crf_layer.update_predictions(predictions_blob)

        workspace.RunNetOnce(model.param_init_net)
        workspace.RunNetOnce(model.net)

        updated_predictions = workspace.FetchBlob(str(updated_predictions))
        ref_predictions = workspace.FetchBlob(str(ref_predictions))
        # assert_allclose requires both the actual and the desired arrays.
        np.testing.assert_allclose(
            updated_predictions,
            ref_predictions,
            atol=1e-4, rtol=1e-4, err_msg='Mismatch in CRF predictions'
        )