test_metric_common.py

# Copyright 2020 HuggingFace Inc.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import doctest
import glob
import importlib
import inspect
import os
import re
from contextlib import contextmanager
from functools import wraps
from unittest.mock import patch

import numpy as np
import pytest
from absl.testing import parameterized

import datasets
from datasets import load_metric

from .utils import for_all_test_methods, local, slow


# mark all tests as integration
pytestmark = pytest.mark.integration


REQUIRE_FAIRSEQ = {"comet"}
_has_fairseq = importlib.util.find_spec("fairseq") is not None

UNSUPPORTED_ON_WINDOWS = {"code_eval"}
_on_windows = os.name == "nt"

REQUIRE_TRANSFORMERS = {"bertscore", "frugalscore", "perplexity"}
_has_transformers = importlib.util.find_spec("transformers") is not None


def skip_if_metric_requires_fairseq(test_case):
    @wraps(test_case)
    def wrapper(self, metric_name):
        if not _has_fairseq and metric_name in REQUIRE_FAIRSEQ:
            self.skipTest('"test requires Fairseq"')
        else:
            test_case(self, metric_name)

    return wrapper


def skip_if_metric_requires_transformers(test_case):
    @wraps(test_case)
    def wrapper(self, metric_name):
        if not _has_transformers and metric_name in REQUIRE_TRANSFORMERS:
            self.skipTest('"test requires transformers"')
        else:
            test_case(self, metric_name)

    return wrapper


def skip_on_windows_if_not_windows_compatible(test_case):
    @wraps(test_case)
    def wrapper(self, metric_name):
        if _on_windows and metric_name in UNSUPPORTED_ON_WINDOWS:
            self.skipTest('"test not supported on Windows"')
        else:
            test_case(self, metric_name)

    return wrapper


def get_local_metric_names():
    metrics = [metric_dir.split(os.sep)[-2] for metric_dir in glob.glob("./metrics/*/")]
    return [{"testcase_name": x, "metric_name": x} for x in metrics if x != "gleu"]  # gleu is unfinished


@parameterized.named_parameters(get_local_metric_names())
@for_all_test_methods(
    skip_if_metric_requires_fairseq, skip_if_metric_requires_transformers, skip_on_windows_if_not_windows_compatible
)
@local
class LocalMetricTest(parameterized.TestCase):
    INTENSIVE_CALLS_PATCHER = {}  # maps a metric name to a context manager that mocks its expensive model calls
    metric_name = None

    @pytest.mark.filterwarnings("ignore:metric_module_factory is deprecated:FutureWarning")
    @pytest.mark.filterwarnings("ignore:load_metric is deprecated:FutureWarning")
    def test_load_metric(self, metric_name):
        doctest.ELLIPSIS_MARKER = "[...]"
        metric_module = importlib.import_module(
            datasets.load.metric_module_factory(os.path.join("metrics", metric_name)).module_path
        )
        metric = datasets.load.import_main_class(metric_module.__name__, dataset=False)
        # check parameters
        parameters = inspect.signature(metric._compute).parameters
        self.assertTrue(all(p.kind != p.VAR_KEYWORD for p in parameters.values()))  # no **kwargs
        # run doctest
        with self.patch_intensive_calls(metric_name, metric_module.__name__):
            with self.use_local_metrics():
                try:
                    results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
                except doctest.UnexpectedException as e:
                    raise e.exc_info[1]  # raise the exception that doctest caught
        self.assertEqual(results.failed, 0)
        self.assertGreater(results.attempted, 1)

    @slow
    def test_load_real_metric(self, metric_name):
        doctest.ELLIPSIS_MARKER = "[...]"
        metric_module = importlib.import_module(
            datasets.load.metric_module_factory(os.path.join("metrics", metric_name)).module_path
        )
        # run doctest
        with self.use_local_metrics():
            results = doctest.testmod(metric_module, verbose=True, raise_on_error=True)
        self.assertEqual(results.failed, 0)
        self.assertGreater(results.attempted, 1)

    @contextmanager
    def patch_intensive_calls(self, metric_name, module_name):
        if metric_name in self.INTENSIVE_CALLS_PATCHER:
            with self.INTENSIVE_CALLS_PATCHER[metric_name](module_name):
                yield
        else:
            yield

    @contextmanager
    def use_local_metrics(self):
        def load_local_metric(metric_name, *args, **kwargs):
            return load_metric(os.path.join("metrics", metric_name), *args, **kwargs)

        # resolve metric names against the local ./metrics folder for the duration of the doctests
        with patch("datasets.load_metric") as mock_load_metric:
            mock_load_metric.side_effect = load_local_metric
            yield

    @classmethod
    def register_intensive_calls_patcher(cls, metric_name):
        def wrapper(patcher):
            patcher = contextmanager(patcher)
            cls.INTENSIVE_CALLS_PATCHER[metric_name] = patcher
            return patcher

        return wrapper


# Metrics intensive calls patchers
# --------------------------------
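# Each patcher below is registered for a single metric and is entered as a
# context manager around that metric's doctest run (see patch_intensive_calls
# above), so the expensive model download / forward pass is replaced by a mock.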


@LocalMetricTest.register_intensive_calls_patcher("bleurt")
def patch_bleurt(module_name):
    import tensorflow.compat.v1 as tf
    from bleurt.score import Predictor

    tf.flags.DEFINE_string("sv", "", "")  # handle pytest cli flags

    class MockedPredictor(Predictor):
        def predict(self, input_dict):
            assert len(input_dict["input_ids"]) == 2
            return np.array([1.03, 1.04])

    # mock predict_fn which is supposed to do a forward pass with a bleurt model
    with patch("bleurt.score._create_predictor") as mock_create_predictor:
        mock_create_predictor.return_value = MockedPredictor()
        yield


@LocalMetricTest.register_intensive_calls_patcher("bertscore")
def patch_bertscore(module_name):
    import torch

    def bert_cos_score_idf(model, refs, *args, **kwargs):
        return torch.tensor([[1.0, 1.0, 1.0]] * len(refs))

    # mock get_model which is supposed to download a bert model
    # mock bert_cos_score_idf which is supposed to do a forward pass with a bert model
    with patch("bert_score.scorer.get_model"), patch(
        "bert_score.scorer.bert_cos_score_idf"
    ) as mock_bert_cos_score_idf:
        mock_bert_cos_score_idf.side_effect = bert_cos_score_idf
        yield


@LocalMetricTest.register_intensive_calls_patcher("comet")
def patch_comet(module_name):
    def load_from_checkpoint(model_path):
        class Model:
            def predict(self, data, *args, **kwargs):
                assert len(data) == 2
                scores = [0.19, 0.92]
                return scores, sum(scores) / len(scores)

        return Model()

    # mock download_model which is supposed to download a comet model
    # mock load_from_checkpoint which is supposed to load the downloaded comet checkpoint
    with patch("comet.download_model") as mock_download_model:
        mock_download_model.return_value = None
        with patch("comet.load_from_checkpoint") as mock_load_from_checkpoint:
            mock_load_from_checkpoint.side_effect = load_from_checkpoint
            yield


def test_seqeval_raises_when_incorrect_scheme():
    metric = load_metric(os.path.join("metrics", "seqeval"))
    wrong_scheme = "ERROR"
    error_message = f"Scheme should be one of [IOB1, IOB2, IOE1, IOE2, IOBES, BILOU], got {wrong_scheme}"
    with pytest.raises(ValueError, match=re.escape(error_message)):
        metric.compute(predictions=[], references=[], scheme=wrong_scheme)
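
# Usage sketch (assumptions: this file lives at tests/test_metric_common.py in a
# checkout of the `datasets` repository with the metric scripts under ./metrics/;
# the exact paths and the environment switch honored by `@slow` may differ):
#
#   pytest tests/test_metric_common.py -k "bleu"      # mocked doctest run for one metric
#   RUN_SLOW=1 pytest tests/test_metric_common.py     # also run the real-model doctests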
