import doctest
import glob
import importlib
import inspect
import os
import re
from contextlib import contextmanager
from functools import wraps
from unittest.mock import patch

import numpy as np
import pytest
from absl.testing import parameterized

import datasets
from datasets import load_metric

from .utils import for_all_test_methods, local, slow


pytestmark = pytest.mark.integration
REQUIRE_FAIRSEQ = {"comet"}
_has_fairseq = importlib.util.find_spec("fairseq") is not None

UNSUPPORTED_ON_WINDOWS = {"code_eval"}
_on_windows = os.name == "nt"

REQUIRE_TRANSFORMERS = {"bertscore", "frugalscore", "perplexity"}
_has_transformers = importlib.util.find_spec("transformers") is not None
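# Each set above pairs with a skip decorator below: tests for the listed metrics are
# skipped when the required optional dependency is missing or the platform is unsupported.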


def skip_if_metric_requires_fairseq(test_case):
    @wraps(test_case)
    def wrapper(self, metric_name):
        if not _has_fairseq and metric_name in REQUIRE_FAIRSEQ:
            self.skipTest("test requires Fairseq")
        else:
            test_case(self, metric_name)

    return wrapper


def skip_if_metric_requires_transformers(test_case):
    @wraps(test_case)
    def wrapper(self, metric_name):
        if not _has_transformers and metric_name in REQUIRE_TRANSFORMERS:
            self.skipTest("test requires transformers")
        else:
            test_case(self, metric_name)

    return wrapper


def skip_on_windows_if_not_windows_compatible(test_case):
    @wraps(test_case)
    def wrapper(self, metric_name):
        if _on_windows and metric_name in UNSUPPORTED_ON_WINDOWS:
            self.skipTest("test not supported on Windows")
        else:
            test_case(self, metric_name)

    return wrapper


def get_local_metric_names():
    metrics = [metric_dir.split(os.sep)[-2] for metric_dir in glob.glob("./metrics/*/")]
    return [{"testcase_name": x, "metric_name": x} for x in metrics if x != "gleu"]


@parameterized.named_parameters(get_local_metric_names())
@for_all_test_methods(
    skip_if_metric_requires_fairseq, skip_if_metric_requires_transformers, skip_on_windows_if_not_windows_compatible
)
@local
class LocalMetricTest(parameterized.TestCase):
    # maps a metric name to a context manager that stubs out its expensive model calls (see patchers below)
    INTENSIVE_CALLS_PATCHER = {}
    metric_name = None  # placeholder; each test receives metric_name from the parameterization

    @pytest.mark.filterwarnings("ignore:metric_module_factory is deprecated:FutureWarning")
    @pytest.mark.filterwarnings("ignore:load_metric is deprecated:FutureWarning")
    def test_load_metric(self, metric_name):
        doctest.ELLIPSIS_MARKER = "[...]"
        metric_module = importlib.import_module(
            datasets.load.metric_module_factory(os.path.join("metrics", metric_name)).module_path
        )
        metric = datasets.load.import_main_class(metric_module.__name__, dataset=False)
        # check that the metric's _compute signature is explicit (no **kwargs)
        parameters = inspect.signature(metric._compute).parameters
        self.assertTrue(all(p.kind != p.VAR_KEYWORD for p in parameters.values()))
        # run the module's doctests with intensive model calls patched out
        with self.patch_intensive_calls(metric_name, metric_module.__name__):
            with self.use_local_metrics():
                try:
                    results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
                except doctest.UnexpectedException as e:
                    raise e.exc_info[1]  # re-raise the exception that doctest caught
        self.assertEqual(results.failed, 0)
        self.assertGreater(results.attempted, 1)

    @slow
    def test_load_real_metric(self, metric_name):
        doctest.ELLIPSIS_MARKER = "[...]"
        metric_module = importlib.import_module(
            datasets.load.metric_module_factory(os.path.join("metrics", metric_name)).module_path
        )
        # run the module's doctests against the real, unpatched implementation
        with self.use_local_metrics():
            results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
        self.assertEqual(results.failed, 0)
        self.assertGreater(results.attempted, 1)

    @contextmanager
    def patch_intensive_calls(self, metric_name, module_name):
        if metric_name in self.INTENSIVE_CALLS_PATCHER:
            with self.INTENSIVE_CALLS_PATCHER[metric_name](module_name):
                yield
        else:
            yield

    @contextmanager
    def use_local_metrics(self):
        def load_local_metric(metric_name, *args, **kwargs):
            return load_metric(os.path.join("metrics", metric_name), *args, **kwargs)

        # redirect datasets.load_metric to the local metric scripts under ./metrics,
        # so the doctests exercise the local version under test
        with patch("datasets.load_metric") as mock_load_metric:
            mock_load_metric.side_effect = load_local_metric
            yield

    @classmethod
    def register_intensive_calls_patcher(cls, metric_name):
        def wrapper(patcher):
            patcher = contextmanager(patcher)
            cls.INTENSIVE_CALLS_PATCHER[metric_name] = patcher
            return patcher

        return wrapper
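

# Metrics intensive calls patchers
# --------------------------------
# Usage sketch (the metric name and patch target below are hypothetical, for
# illustration only): the decorated generator is wrapped in contextmanager() and
# registered so that patch_intensive_calls enters it during test_load_metric:
#
#   @LocalMetricTest.register_intensive_calls_patcher("my_metric")
#   def patch_my_metric(module_name):
#       with patch("my_metric_package.expensive_forward_pass") as mock_forward:
#           mock_forward.return_value = 0.5
#           yield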


@LocalMetricTest.register_intensive_calls_patcher("bleurt")
def patch_bleurt(module_name):
    import tensorflow.compat.v1 as tf
    from bleurt.score import Predictor

    tf.flags.DEFINE_string("sv", "", "")  # handle pytest cli flags

    class MockedPredictor(Predictor):
        def predict(self, input_dict):
            assert len(input_dict["input_ids"]) == 2
            return np.array([1.03, 1.04])

    # mock _create_predictor, which would otherwise run a forward pass with a real BLEURT model
    with patch("bleurt.score._create_predictor") as mock_create_predictor:
        mock_create_predictor.return_value = MockedPredictor()
        yield


@LocalMetricTest.register_intensive_calls_patcher("bertscore")
def patch_bertscore(module_name):
    import torch

    def bert_cos_score_idf(model, refs, *args, **kwargs):
        return torch.tensor([[1.0, 1.0, 1.0]] * len(refs))

    # mock get_model (would download a BERT model) and bert_cos_score_idf (would run a forward pass)
    with patch("bert_score.scorer.get_model"), patch(
        "bert_score.scorer.bert_cos_score_idf"
    ) as mock_bert_cos_score_idf:
        mock_bert_cos_score_idf.side_effect = bert_cos_score_idf
        yield


@LocalMetricTest.register_intensive_calls_patcher("comet")
def patch_comet(module_name):
    def load_from_checkpoint(model_path):
        class Model:
            def predict(self, data, *args, **kwargs):
                assert len(data) == 2
                scores = [0.19, 0.92]
                return scores, sum(scores) / len(scores)

        return Model()

    # mock download_model and load_from_checkpoint, which would otherwise fetch a real COMET checkpoint
    with patch("comet.download_model") as mock_download_model:
        mock_download_model.return_value = None
        with patch("comet.load_from_checkpoint") as mock_load_from_checkpoint:
            mock_load_from_checkpoint.side_effect = load_from_checkpoint
            yield


def test_seqeval_raises_when_incorrect_scheme():
    metric = load_metric(os.path.join("metrics", "seqeval"))
    wrong_scheme = "ERROR"
    error_message = f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {wrong_scheme}"
    with pytest.raises(ValueError, match=re.escape(error_message)):
        metric.compute(predictions=[], references=[], scheme=wrong_scheme)