test_named_entity_extractor.py
import pytest

from haystack import ComponentError, Document, Pipeline
from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend


@pytest.fixture
def raw_texts():
    return [
        "My name is Clara and I live in Berkeley, California.",
        "I'm Merlin, the happy pig!",
        "New York State declared a state of emergency after the announcement of the end of the world.",
        "",  # Intentionally empty.
    ]


@pytest.fixture
def hf_annotations():
    return [
        [
            NamedEntityAnnotation(entity="PER", start=11, end=16),
            NamedEntityAnnotation(entity="LOC", start=31, end=39),
            NamedEntityAnnotation(entity="LOC", start=41, end=51),
        ],
        [NamedEntityAnnotation(entity="PER", start=4, end=10)],
        [NamedEntityAnnotation(entity="LOC", start=0, end=14)],
        [],
    ]


@pytest.fixture
def spacy_annotations():
    return [
        [
            NamedEntityAnnotation(entity="PERSON", start=11, end=16),
            NamedEntityAnnotation(entity="GPE", start=31, end=39),
            NamedEntityAnnotation(entity="GPE", start=41, end=51),
        ],
        [NamedEntityAnnotation(entity="PERSON", start=4, end=10)],
        [NamedEntityAnnotation(entity="GPE", start=0, end=14)],
        [],
    ]


def test_ner_extractor_init():
    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

    with pytest.raises(ComponentError, match=r"not initialized"):
        extractor.run(documents=[])

    assert not extractor.initialized
    extractor.warm_up()
    assert extractor.initialized


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
    extractor.warm_up()

    _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
    extractor.warm_up()

    _extract_and_check_predictions(extractor, raw_texts, spacy_annotations, batch_size)


@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size):
    pipeline = Pipeline()
    pipeline.add_component(
        name="ner_extractor",
        instance=NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER"),
    )

    outputs = pipeline.run(
        {"ner_extractor": {"documents": [Document(content=text) for text in raw_texts], "batch_size": batch_size}}
    )["ner_extractor"]["documents"]
    predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
    _check_predictions(predicted, hf_annotations)


def _extract_and_check_predictions(extractor, texts, expected, batch_size):
    docs = [Document(content=text) for text in texts]
    outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"]
    assert all(id(a) == id(b) for a, b in zip(docs, outputs))
    predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]

    _check_predictions(predicted, expected)


def _check_predictions(predicted, expected):
    assert len(predicted) == len(expected)
    for pred, exp in zip(predicted, expected):
        assert len(pred) == len(exp)

        for a, b in zip(pred, exp):
            assert a.entity == b.entity
            assert a.start == b.start
            assert a.end == b.end