lm-evaluation-harness

Форк
0
400 строк · 13.5 Кб
1
import itertools
2

3
import numpy as np
4
import pytest
5
import torch
6

7
from lm_eval.api.metrics import (
8
    aggregate_subtask_metrics,
9
    mean,
10
    pooled_sample_stderr,
11
    stderr_for_metric,
12
)
13
from lm_eval.models.utils import Collator
14
from lm_eval.utils import (
15
    get_rolling_token_windows,
16
    make_disjoint_window,
17
)
18

19

20
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v1():
    """context_len=1 over 34 tokens with max_seq_len=10.

    Consecutive windows overlap by one token (the last target of a window is
    the first input of the next). The final, short window predicts the four
    leftover tokens while its input is back-filled with earlier tokens so it
    remains 10 tokens long.
    """
    expected = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        (
            [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
            [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        ),
        (
            [19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
            [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        ),
        ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [30, 31, 32, 33]),
    ]
    tokens = list(range(34))
    windows = [
        (context, targets)
        for context, targets in get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    ]
    # The target spans of all windows together must be as long as the input.
    assert sum(len(targets) for _, targets in windows) == len(tokens)
    assert expected == windows
48

49

50
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v2():
    """context_len=8 over 34 tokens with max_seq_len=10.

    After the full first window, every later window predicts three new tokens
    and keeps a 10-token input drawn from the tokens just before them.
    """
    gold = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        ([2, 3, 4, 5, 6, 7, 8, 9, 10, 11], [10, 11, 12]),
        ([5, 6, 7, 8, 9, 10, 11, 12, 13, 14], [13, 14, 15]),
        ([8, 9, 10, 11, 12, 13, 14, 15, 16, 17], [16, 17, 18]),
        ([11, 12, 13, 14, 15, 16, 17, 18, 19, 20], [19, 20, 21]),
        ([14, 15, 16, 17, 18, 19, 20, 21, 22, 23], [22, 23, 24]),
        ([17, 18, 19, 20, 21, 22, 23, 24, 25, 26], [25, 26, 27]),
        ([20, 21, 22, 23, 24, 25, 26, 27, 28, 29], [28, 29, 30]),
        ([23, 24, 25, 26, 27, 28, 29, 30, 31, 32], [31, 32, 33]),
    ]
    x = list(range(34))
    collected = []
    for window in get_rolling_token_windows(
        token_list=x,
        prefix_token=-100,
        max_seq_len=10,
        context_len=8,
    ):
        context_part, target_part = window
        collected.append((context_part, target_part))
    # Every token of the input must appear in exactly as many target slots
    # as there are input tokens in total.
    predicted_total = sum(len(target_part) for _, target_part in collected)
    assert predicted_total == len(x)
    assert gold == collected
77

78

79
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v3():
    """context_len equal to max_seq_len (10) over 34 tokens.

    After the full first window, each window predicts exactly one token and
    its input is the 10 tokens immediately preceding that token, giving
    1 + 24 windows in total.
    """
    seq_len = 10
    gold = [([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]
    # Follow-up windows: input is tokens[t - 10 : t], target is [t].
    gold.extend((list(range(t - seq_len, t)), [t]) for t in range(seq_len, 34))
    x = list(range(34))
    windows = get_rolling_token_windows(
        token_list=x,
        prefix_token=-100,
        max_seq_len=seq_len,
        context_len=seq_len,
    )
    output = [(inputs, preds) for inputs, preds in windows]
    assert sum(len(preds) for _, preds in output) == len(x)
    assert gold == output
122

123

124
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v4():
    """Same configuration as v3 (context_len == max_seq_len == 10), 30 tokens.

    After the full first window, each window predicts one token from the 10
    tokens before it, giving 1 + 20 windows in total.
    """
    seq_len = 10
    n_tokens = 30
    gold = [([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9])]
    # Follow-up windows: input is tokens[t - 10 : t], target is [t].
    gold += [
        (list(range(t - seq_len, t)), [t]) for t in range(seq_len, n_tokens)
    ]
    x = list(range(n_tokens))
    output = [
        (inputs, preds)
        for inputs, preds in get_rolling_token_windows(
            token_list=x,
            prefix_token=-100,
            max_seq_len=seq_len,
            context_len=seq_len,
        )
    ]
    assert sum(map(len, (preds for _, preds in output))) == len(x)
    assert gold == output
163

164

165
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v5():
    """context_len=1 over 30 tokens with max_seq_len=10.

    30 tokens split evenly into three full 10-token target windows, so no
    short trailing window is produced.
    """
    gold = [
        ([-100, 0, 1, 2, 3, 4, 5, 6, 7, 8], [0, 1, 2, 3, 4, 5, 6, 7, 8, 9]),
        (
            [9, 10, 11, 12, 13, 14, 15, 16, 17, 18],
            [10, 11, 12, 13, 14, 15, 16, 17, 18, 19],
        ),
        (
            [19, 20, 21, 22, 23, 24, 25, 26, 27, 28],
            [20, 21, 22, 23, 24, 25, 26, 27, 28, 29],
        ),
    ]
    x = list(range(30))
    windows = get_rolling_token_windows(
        token_list=x,
        prefix_token=-100,
        max_seq_len=10,
        context_len=1,
    )
    total_predicted = 0
    output = []
    for context_ids, target_ids in windows:
        total_predicted += len(target_ids)
        output.append((context_ids, target_ids))
    assert total_predicted == len(x)
    assert gold == output
192

193

194
# noinspection DuplicatedCode
def test_get_rolling_token_windows_v6():
    """Tiny windows: max_seq_len=2, context_len=1 over 9 tokens.

    Windows advance two targets at a time; the last window predicts the single
    leftover token, with its input back-filled to two tokens.
    """
    gold = [
        ([-100, 0], [0, 1]),
        ([1, 2], [2, 3]),
        ([3, 4], [4, 5]),
        ([5, 6], [6, 7]),
        ([6, 7], [8]),
    ]
    x = list(range(9))
    output = []
    n_predicted = 0
    for pair in get_rolling_token_windows(
        token_list=x, prefix_token=-100, max_seq_len=2, context_len=1
    ):
        window_in, window_out = pair
        output.append((window_in, window_out))
        n_predicted += len(window_out)
    assert n_predicted == len(x)
    assert gold == output
217

218

219
def test_get_rolling_token_windows_empty():
    """An empty token list must yield no windows at all."""
    windows = get_rolling_token_windows(
        token_list=[],
        prefix_token=-100,
        max_seq_len=2,
        context_len=1,
    )
    # Drain the generator completely and count what it produced.
    produced = sum(1 for _ in windows)
    assert produced == 0
230

231

232
def test_make_disjoint_window():
    """make_disjoint_window drops trailing context tokens that overlap the targets.

    In each case below, ``len(targets) - 1`` tokens are removed from the end of
    the context while the targets are returned unchanged.
    """
    cases = [
        (([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]), ([1], [2, 3, 4, 5, 6])),
        (([1, 2, 3, 4, 5], [4, 5, 6]), ([1, 2, 3], [4, 5, 6])),
        (([1, 2, 3, 4, 5], [6]), ([1, 2, 3, 4, 5], [6])),
    ]
    for window, expected in cases:
        assert make_disjoint_window(window) == expected
239

240

241
class TestCollator:
    """Tests for ``Collator``: batch sizing, longest-first ordering, grouping,
    and restoring the original request order with ``get_original``."""

    def make_generate_sample(self, end=10):
        """Build ``end`` generation requests ``(prompt, gen_kwargs)``.

        Prompts are ``"x" * i`` for ``i`` in ``1..end``. The first ``end // 2``
        requests share one ``gen_kwargs`` dict and the remaining ones share a
        second dict, so at most two distinct ``gen_kwargs`` values appear.
        """
        strings = ["x" * i for i in range(1, end + 1)]
        gen_kwargs1, gen_kwargs2 = (
            {"temperature": 0},
            {"temperature": 0, "until": ["nn", "\n\n"]},
        )
        split_at = len(strings) // 2
        return [
            (string, gen_kwargs1 if i < split_at else gen_kwargs2)
            for i, string in enumerate(strings)
        ]

    def make_loglikelihood_sample(self, end=11):
        """Build ``end`` loglikelihood requests ``(("x", "x"), [1, ..., k])``
        for ``k`` in ``1..end``, i.e. token lists of strictly increasing length."""
        return [
            (("x", "x"), list(range(1, total_length + 1)))
            for total_length in range(1, end + 1)
        ]

    def make_loglikelihood_sample_group(self, end=11):
        """Build 18 requests ``(("x", "x"), context, continuation)`` over one
        shared 8-token context.

        The first 9 requests have a single continuation token ``[x]``; the next
        9 have three tokens ``[x, x + 9, x + 18]`` for ``x`` in ``0..8``.

        NOTE(review): ``end`` is accepted but unused; it is kept only so the
        signature stays backward-compatible.
        """
        single = [(("x", "x"), [1, 2, 3, 4, 5, 6, 7, 8], [x]) for x in range(9)]
        triple = [
            (("x", "x"), [1, 2, 3, 4, 5, 6, 7, 8], [x, y, z])
            for x, y, z in zip(range(9), range(9, 18), range(18, 27))
        ]
        return single + triple

    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 9)])
    def test_generations(self, batch_size, end):
        """Batches respect ``batch_size``, are sorted longest-prompt-first, never
        mix ``gen_kwargs``, and ``get_original`` restores the input order."""

        def _collate_gen(x):
            # Longest prompt first; ties broken by the prompt text.
            return -len(x[0]), x[0]

        generation_samples = self.make_generate_sample(int(end))
        gens = Collator(generation_samples, _collate_gen, group_by="gen_kwargs")
        batches = gens.get_batched(n=int(batch_size), batch_fn=None)
        # Sizes of the two gen_kwargs groups built by make_generate_sample. With
        # batch_size == 0 the test expects each group to arrive as one batch.
        group_one = end // 2
        group_two = end - end // 2
        output = []
        for batch in batches:
            # check batching
            if batch_size != 0:
                assert len(batch) <= batch_size
            else:
                assert len(batch) in [group_one, group_two]
            # check if reorder-er is working correctly
            assert all(
                len(cur[0]) <= len(prev[0]) for prev, cur in zip(batch, batch[1:])
            )
            # check if grouping correctly
            assert all(x[1] == batch[0][1] for x in batch)
            output.extend(batch)
        reordered_output = gens.get_original(output)
        # check get original
        assert reordered_output == generation_samples

    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 3)])
    def test_loglikelihood(self, batch_size, end):
        """Batches respect ``batch_size``, are sorted longest-token-list-first,
        and ``get_original`` restores the input order."""

        def _collate_log(x):
            # Longest token list first; ties broken by the tokens themselves.
            return -len(x[1]), tuple(x[1])

        loglikelihood_samples = self.make_loglikelihood_sample(int(end))
        loglikelihoods = Collator(
            loglikelihood_samples,
            _collate_log,
        )
        batches = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
        output = []
        for batch in batches:
            # check batching
            if batch_size != 0:
                assert len(batch) <= batch_size
            else:
                # batch_size == 0 is expected to put every request in one batch.
                assert len(batch) == end
            # check reorder
            assert all(
                len(cur[1]) <= len(prev[1]) for prev, cur in zip(batch, batch[1:])
            )
            output.extend(x[1] for x in batch)
        # check indices
        reordered_output = loglikelihoods.get_original(output)
        assert reordered_output == [x[1] for x in loglikelihood_samples]

    @pytest.mark.parametrize("batch_size", [17, 8, 12, 0])
    def test_context_grouping(self, batch_size):
        """Requests grouped by ``context + continuation[:-1]`` are batched, then
        fanned back out through ``get_cache``; ``get_original`` must still
        recover every request in its original order."""

        def _collate_log(x):
            # Sort by total (context + continuation) length, longest first.
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        loglikelihood_samples = self.make_loglikelihood_sample_group()
        loglikelihoods = Collator(
            loglikelihood_samples,
            _collate_log,
            group_fn=lambda a: a[-2] + a[-1][:-1],
            group_by="contexts",
        )
        batches = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
        output = []
        cont_outputs = []
        for batch in batches:
            # check batching
            if batch_size != 0:
                assert len(batch) <= batch_size
            # check reorder
            assert all(
                len(cur[1]) <= len(prev[1]) for prev, cur in zip(batch, batch[1:])
            )
            for x in batch:
                # get_cache is expected to yield one entry per original request
                # that shares this request's group key.
                for _, cont_toks, _ in loglikelihoods.get_cache(
                    req_str="".join(x[0]),
                    cxt_toks=x[1],
                    cont_toks=x[2],
                    logits=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
                    .unsqueeze(0)
                    .unsqueeze(0),
                ):
                    output.append(x[1])
                    cont_outputs.append(cont_toks)
        assert len(output) == len(cont_outputs)
        # check indices
        reordered_output = loglikelihoods.get_original(output)
        assert reordered_output == [x[1] for x in loglikelihood_samples]
365

366

367
def test_aggregate_mean():
    """aggregate_subtask_metrics must honour ``weight_by_size``.

    The expected values are the plain mean of the subtask scores,
    (0.3 + 0.2 + 0.4) / 3 = 0.3, and the size-weighted mean,
    (0.3*20 + 0.2*40 + 0.4*100) / 160 = 0.3375.
    """
    # The results are computed floating-point means; compare with a tolerance
    # so the test does not depend on the implementation's summation order.
    assert aggregate_subtask_metrics(
        [0.3, 0.2, 0.4], [20, 40, 100], weight_by_size=False
    ) == pytest.approx(0.3)
    assert aggregate_subtask_metrics(
        [0.3, 0.2, 0.4], [20, 40, 100], weight_by_size=True
    ) == pytest.approx(0.3375)
377

378

379
@pytest.mark.parametrize(
    "samples",
    [
        [[1.0] * 40 + [0.0] * 60, [1.0] * 30 + [0.0] * 30, [1.0] * 20 + [0.0] * 60],
        [[1.0] * 35 + [0.0] * 65, [1.0] * 20 + [0.0] * 20],
    ],
)
def test_aggregate_stderrs(samples):
    """Pooling per-subtask bootstrap stderrs ~matches the bootstrap stderr of
    all samples taken together (within ``atol=1e-3``).

    ``samples`` is a list of subtasks, each a list of 0/1 scores.
    """
    bootstrap_mean_stderr = stderr_for_metric(metric=mean, bootstrap_iters=100000)

    subtask_stderrs = []
    subtask_sizes = []
    for subtask in samples:
        subtask_stderrs.append(bootstrap_mean_stderr(subtask))
        subtask_sizes.append(len(subtask))

    pooled = pooled_sample_stderr(subtask_stderrs, subtask_sizes)
    whole_set = bootstrap_mean_stderr(
        [score for subtask in samples for score in subtask]
    )
    assert np.allclose(pooled, whole_set, atol=1.0e-3)
401

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.