import pytest

from haystack import ComponentError, Document, Pipeline
from haystack.components.extractors import NamedEntityAnnotation, NamedEntityExtractor, NamedEntityExtractorBackend
@pytest.fixture
def raw_texts():
    """Sample input texts: multi-entity, single-entity, multi-word-entity, and empty."""
    return [
        "My name is Clara and I live in Berkeley, California.",
        "I'm Merlin, the happy pig!",
        "New York State declared a state of emergency after the announcement of the end of the world.",
        "",  # Intentionally empty.
    ]
@pytest.fixture
def hf_annotations():
    """Expected annotations for `raw_texts` under the Hugging Face backend (PER/LOC tag set).

    One inner list per input text; the empty text yields no annotations.
    """
    return [
        [
            NamedEntityAnnotation(entity="PER", start=11, end=16),
            NamedEntityAnnotation(entity="LOC", start=31, end=39),
            NamedEntityAnnotation(entity="LOC", start=41, end=51),
        ],
        [NamedEntityAnnotation(entity="PER", start=4, end=10)],
        [NamedEntityAnnotation(entity="LOC", start=0, end=14)],
        [],  # the intentionally empty text produces no entities
    ]
@pytest.fixture
def spacy_annotations():
    """Expected annotations for `raw_texts` under the spaCy backend (PERSON/GPE tag set).

    One inner list per input text; the empty text yields no annotations.
    """
    return [
        [
            NamedEntityAnnotation(entity="PERSON", start=11, end=16),
            NamedEntityAnnotation(entity="GPE", start=31, end=39),
            NamedEntityAnnotation(entity="GPE", start=41, end=51),
        ],
        [NamedEntityAnnotation(entity="PERSON", start=4, end=10)],
        [NamedEntityAnnotation(entity="GPE", start=0, end=14)],
        [],  # the intentionally empty text produces no entities
    ]
def test_ner_extractor_init():
    """Running before warm-up must raise; warm-up must flip `initialized`."""
    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")

    # Invoking the component before warm-up is an error.
    with pytest.raises(ComponentError, match=r"not initialized"):
        extractor.run(documents=[])

    assert not extractor.initialized

    # NOTE(review): the warm-up call was lost in the garbled source; it is required
    # for `initialized` to become true — confirm against the upstream test.
    extractor.warm_up()
    assert extractor.initialized
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_hf_backend(raw_texts, hf_annotations, batch_size):
    """End-to-end extraction with the Hugging Face backend."""
    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER")
    # Required before `run` — an uninitialized extractor raises ComponentError.
    extractor.warm_up()

    _extract_and_check_predictions(extractor, raw_texts, hf_annotations, batch_size)
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_spacy_backend(raw_texts, spacy_annotations, batch_size):
    """End-to-end extraction with the spaCy backend."""
    extractor = NamedEntityExtractor(backend=NamedEntityExtractorBackend.SPACY, model="en_core_web_trf")
    # Required before `run` — an uninitialized extractor raises ComponentError.
    extractor.warm_up()

    _extract_and_check_predictions(extractor, raw_texts, spacy_annotations, batch_size)
@pytest.mark.parametrize("batch_size", [1, 3])
def test_ner_extractor_in_pipeline(raw_texts, hf_annotations, batch_size):
    """The extractor behaves identically when driven through a Pipeline."""
    pipeline = Pipeline()
    pipeline.add_component(
        # Component name must match the key used in the run payload below.
        name="ner_extractor",
        instance=NamedEntityExtractor(backend=NamedEntityExtractorBackend.HUGGING_FACE, model="dslim/bert-base-NER"),
    )

    outputs = pipeline.run(
        {"ner_extractor": {"documents": [Document(content=text) for text in raw_texts], "batch_size": batch_size}}
    )["ner_extractor"]["documents"]
    predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]
    _check_predictions(predicted, hf_annotations)
def _extract_and_check_predictions(extractor, texts, expected, batch_size):
    """Run `extractor` over `texts` and assert its stored annotations equal `expected`.

    :param extractor: a warmed-up NamedEntityExtractor.
    :param texts: raw input strings, one Document is built per string.
    :param expected: per-text lists of expected annotations.
    :param batch_size: forwarded to `extractor.run`.
    """
    docs = [Document(content=text) for text in texts]
    outputs = extractor.run(documents=docs, batch_size=batch_size)["documents"]
    # The extractor must annotate the documents in place, not copy them —
    # `is` expresses the identity check more idiomatically than id() comparison.
    assert all(a is b for a, b in zip(docs, outputs))
    predicted = [NamedEntityExtractor.get_stored_annotations(doc) for doc in outputs]

    _check_predictions(predicted, expected)
96
def _check_predictions(predicted, expected):
97
assert len(predicted) == len(expected)
98
for pred, exp in zip(predicted, expected):
99
assert len(pred) == len(exp)
101
for a, b in zip(pred, exp):
102
assert a.entity == b.entity
103
assert a.start == b.start
104
assert a.end == b.end