pytorch

Форк
0
/
crf_viterbi_test.py 
42 строки · 1.6 Кб
1

2

3

4

5
from caffe2.python import workspace, crf
6

7
from caffe2.python.cnn import CNNModelHelper
8
from caffe2.python.crf_predict import crf_update_predictions
9
from caffe2.python.test_util import TestCase
10
import hypothesis.strategies as st
11
from hypothesis import given, settings
12
import numpy as np
13

14

15
class TestCrfDecode(TestCase):
16

17
    @given(num_tags=st.integers(2, 4), num_words=st.integers(2, 15))
18
    @settings(deadline=2000)
19
    def test_crf_viterbi(self, num_tags, num_words):
20
        model = CNNModelHelper(name='external')
21
        predictions = np.random.randn(num_words, num_tags).astype(np.float32)
22
        transitions = np.random.uniform(
23
            low=-1, high=1, size=(num_tags + 2, num_tags + 2)
24
        ).astype(np.float32)
25
        predictions_blob, transitions_blob = (
26
            model.net.AddExternalInputs('predictions', 'crf_transitions')
27
        )
28
        workspace.FeedBlob(str(transitions_blob), transitions)
29
        workspace.FeedBlob(str(predictions_blob), predictions)
30
        crf_layer = crf.CRFWithLoss(model, num_tags, transitions_blob)
31

32
        updated_predictions = crf_update_predictions(
33
            model, crf_layer, predictions_blob
34
        )
35
        ref_predictions = crf_layer.update_predictions(predictions_blob)
36

37
        workspace.RunNetOnce(model.param_init_net)
38
        workspace.RunNetOnce(model.net)
39

40
        updated_predictions = workspace.FetchBlob(str(updated_predictions))
41
        ref_predictions = workspace.FetchBlob(str(ref_predictions))
42
        np.testing.assert_allclose(
43
            updated_predictions,
44
            ref_predictions,
45
            atol=1e-4, rtol=1e-4, err_msg='Mismatch in CRF predictions'
46
        )
47

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.