# lm-evaluation-harness
# 400 lines · 13.5 KB
1import itertools2
3import numpy as np4import pytest5import torch6
7from lm_eval.api.metrics import (8aggregate_subtask_metrics,9mean,10pooled_sample_stderr,11stderr_for_metric,12)
13from lm_eval.models.utils import Collator14from lm_eval.utils import (15get_rolling_token_windows,16make_disjoint_window,17)
18
19
20# noinspection DuplicatedCode
def test_get_rolling_token_windows_v1():
    """Rolling windows over 34 tokens with max_seq_len=10, context_len=1.

    Three full-stride windows of 10 predicted tokens each, plus a final
    short window that reuses extra context so every token is predicted
    exactly once.
    """
    expected = [
        ([-100] + list(range(9)), list(range(10))),
        (list(range(9, 19)), list(range(10, 20))),
        (list(range(19, 29)), list(range(20, 30))),
        (list(range(23, 33)), list(range(30, 34))),
    ]
    tokens = list(range(34))
    windows = list(
        get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    )
    # every token must be predicted exactly once across all windows
    assert sum(len(pred) for _, pred in windows) == len(tokens)
    assert windows == expected
49
50# noinspection DuplicatedCode
def test_get_rolling_token_windows_v2():
    """Rolling windows over 34 tokens with max_seq_len=10, context_len=8.

    After the initial full window, each window starts 3 tokens later and
    predicts 3 new tokens (contexts start at 2, 5, ..., 23).
    """
    expected = [([-100] + list(range(9)), list(range(10)))] + [
        (list(range(start, start + 10)), list(range(start + 8, start + 11)))
        for start in range(2, 24, 3)
    ]
    tokens = list(range(34))
    windows = list(
        get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=8,
        )
    )
    # every token must be predicted exactly once across all windows
    assert sum(len(pred) for _, pred in windows) == len(tokens)
    assert windows == expected
78
79# noinspection DuplicatedCode
def test_get_rolling_token_windows_v3():
    """Rolling windows over 34 tokens with context_len == max_seq_len == 10.

    Maximal context: after the initial full window, every window slides by
    one token and predicts exactly one new token.
    """
    expected = [([-100] + list(range(9)), list(range(10)))] + [
        (list(range(start, start + 10)), [start + 10]) for start in range(24)
    ]
    tokens = list(range(34))
    windows = list(
        get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=10,
        )
    )
    # every token must be predicted exactly once across all windows
    assert sum(len(pred) for _, pred in windows) == len(tokens)
    assert windows == expected
123
124# noinspection DuplicatedCode
def test_get_rolling_token_windows_v4():
    """Rolling windows over 30 tokens with context_len == max_seq_len == 10.

    Same maximal-context pattern as v3, but on a shorter token list: the
    initial full window followed by 20 one-token-stride windows.
    """
    expected = [([-100] + list(range(9)), list(range(10)))] + [
        (list(range(start, start + 10)), [start + 10]) for start in range(20)
    ]
    tokens = list(range(30))
    windows = list(
        get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=10,
        )
    )
    # every token must be predicted exactly once across all windows
    assert sum(len(pred) for _, pred in windows) == len(tokens)
    assert windows == expected
164
165# noinspection DuplicatedCode
def test_get_rolling_token_windows_v5():
    """Rolling windows over 30 tokens with max_seq_len=10, context_len=1.

    The token count divides evenly: exactly three full windows of 10
    predictions each, with no short tail window.
    """
    expected = [
        ([-100] + list(range(9)), list(range(10))),
        (list(range(9, 19)), list(range(10, 20))),
        (list(range(19, 29)), list(range(20, 30))),
    ]
    tokens = list(range(30))
    windows = list(
        get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=10,
            context_len=1,
        )
    )
    # every token must be predicted exactly once across all windows
    assert sum(len(pred) for _, pred in windows) == len(tokens)
    assert windows == expected
193
194# noinspection DuplicatedCode
def test_get_rolling_token_windows_v6():
    """Tiny windows: 9 tokens with max_seq_len=2, context_len=1.

    Four two-token windows plus a final one-token window that reuses extra
    context, so all 9 tokens are predicted.
    """
    expected = [
        ([-100, 0], [0, 1]),
        ([1, 2], [2, 3]),
        ([3, 4], [4, 5]),
        ([5, 6], [6, 7]),
        ([6, 7], [8]),
    ]
    tokens = list(range(9))
    windows = list(
        get_rolling_token_windows(
            token_list=tokens,
            prefix_token=-100,
            max_seq_len=2,
            context_len=1,
        )
    )
    # every token must be predicted exactly once across all windows
    assert sum(len(pred) for _, pred in windows) == len(tokens)
    assert windows == expected
218
def test_get_rolling_token_windows_empty():
    """An empty token list must yield no windows at all."""
    windows = get_rolling_token_windows(
        token_list=[],
        prefix_token=-100,
        max_seq_len=2,
        context_len=1,
    )
    assert sum(1 for _ in windows) == 0
231
def test_make_disjoint_window():
    """The context is truncated so it no longer overlaps the continuation."""
    cases = [
        (([1, 2, 3, 4, 5], [2, 3, 4, 5, 6]), ([1], [2, 3, 4, 5, 6])),
        (([1, 2, 3, 4, 5], [4, 5, 6]), ([1, 2, 3], [4, 5, 6])),
        (([1, 2, 3, 4, 5], [6]), ([1, 2, 3, 4, 5], [6])),
    ]
    for window, want in cases:
        assert make_disjoint_window(window) == want
240
class TestCollator:
    """Exercises lm_eval.models.utils.Collator: length-sorted batching,
    grouping (by gen_kwargs or by shared contexts), cache retrieval, and
    restoration of the original request order via get_original()."""

    def make_generate_sample(self, end: int = 10):
        """Return `end` (string, gen_kwargs) pairs with strings of length
        1..end; the first half share gen_kwargs1 and the rest gen_kwargs2,
        giving `group_by="gen_kwargs"` two groups to separate."""
        strings = ["x" * i for i in range(1, end + 1)]
        gen_kwargs1, gen_kwargs2 = (
            {"temperature": 0},
            {"temperature": 0, "until": ["nn", "\n\n"]},
        )
        args = [
            (string, gen_kwargs1 if i < len(strings) // 2 else gen_kwargs2)
            for i, string in enumerate(strings)
        ]

        return args

    def make_loglikelihood_sample(self, end: int = 11):
        """Return `end` ((ctx, cont), tokens) pairs whose token lists grow
        from length 1 to `end`."""
        samples = [
            (("x", "x"), list(range(1, total_length + 1)))
            for total_length in range(1, end + 1)
        ]
        return samples

    def make_loglikelihood_sample_group(self, end: int = 11):
        """Return requests that all share the context tokens [1..8]: nine
        with a 1-token continuation and nine with a 3-token continuation.
        NOTE(review): `end` is currently unused — sizes are hard-coded."""
        a = [(("x", "x"), [1, 2, 3, 4, 5, 6, 7, 8], [x]) for x in range(9)]
        b = [
            (("x", "x"), [1, 2, 3, 4, 5, 6, 7, 8], [x, y, z])
            for x, y, z in zip(range(9), range(9, 18), range(18, 27))
        ]
        return a + b

    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 9)])
    def test_generations(self, batch_size, end):
        # sort key: longest request string first (ties broken by the string)
        _collate_gen = lambda x: (-len(x[0]), x[0])  # noqa: E731

        generation_samples = self.make_generate_sample(int(end))
        gens = Collator(generation_samples, _collate_gen, group_by="gen_kwargs")
        chunks = gens.get_batched(n=int(batch_size), batch_fn=None)
        output = []
        # NOTE: the loop variable deliberately shadows the batch iterator
        for chunks in chunks:
            # check batching: with n=0 each batch must be one whole kwargs group
            group_one = end // 2
            group_two = end - end // 2
            assert (
                len(chunks) <= batch_size
                if batch_size != 0
                else len(chunks) in [group_one, group_two]
            )
            # check if reorder-er is working correctly (non-increasing length)
            assert all(
                len(chunks[i][0]) <= len(chunks[i - 1][0])
                for i in range(1, len(chunks))
            )
            # check if grouping correctly: a single gen_kwargs dict per batch
            assert all(x[1] == chunks[0][1] for x in chunks)
            for x in chunks:
                output.append(x)
        reordered_output = gens.get_original(output)
        # check get_original restores the original submission order
        assert reordered_output == generation_samples

    @pytest.mark.parametrize("batch_size, end", [(17, 30), (8, 61), (12, 48), (0, 3)])
    def test_loglikelihood(self, batch_size, end):
        # sort key: longest token list first (ties broken by the tokens)
        _collate_log = lambda x: (-len(x[1]), tuple(x[1]))  # noqa: E731
        loglikelihood_samples = self.make_loglikelihood_sample(int(end))
        loglikelihoods = Collator(
            loglikelihood_samples,
            _collate_log,
        )
        chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
        output = []
        # NOTE: the loop variable deliberately shadows the batch iterator
        for chunks in chunks:
            # check batching: n=0 means a single batch holding all requests
            assert len(chunks) <= batch_size if batch_size != 0 else len(chunks) == end
            # check reorder (non-increasing token-list length within a batch)
            assert all(
                len(chunks[i][1]) <= len(chunks[i - 1][1])
                for i in range(1, len(chunks))
            )
            for x in chunks:
                output.append(x[1])
        # check indices: get_original must undo the length sort
        reordered_output = loglikelihoods.get_original(output)
        assert reordered_output == [x[1] for x in loglikelihood_samples]

    @pytest.mark.parametrize("batch_size", [17, 8, 12, 0])
    def test_context_grouping(self, batch_size):
        def _collate(x):
            # sort key over context + continuation tokens combined
            toks = x[1] + x[2]
            return -len(toks), tuple(toks)

        _collate_log = _collate  # noqa: E731
        loglikelihood_samples = self.make_loglikelihood_sample_group()
        # group_fn: context tokens plus all-but-last continuation token, i.e.
        # requests sharing that prefix can be served by one forward pass
        loglikelihoods = Collator(
            loglikelihood_samples,
            _collate_log,
            group_fn=lambda a: a[-2] + a[-1][:-1],
            group_by="contexts",
        )
        chunks = loglikelihoods.get_batched(n=int(batch_size), batch_fn=None)
        output = []
        outputs_ = []
        # NOTE: the loop variable deliberately shadows the batch iterator
        for chunks in chunks:
            # check batching
            if batch_size != 0:
                assert len(chunks) <= batch_size
            # check reorder (non-increasing context length within a batch)
            assert all(
                len(chunks[i][1]) <= len(chunks[i - 1][1])
                for i in range(1, len(chunks))
            )
            for x in chunks:
                # get_cache fans one cached forward pass (logits shaped
                # (1, 1, 8) here) back out to every grouped request
                for request_str, cont_toks, logits in loglikelihoods.get_cache(
                    req_str="".join(x[0]),
                    cxt_toks=x[1],
                    cont_toks=x[2],
                    logits=torch.tensor([1, 2, 3, 4, 5, 6, 7, 8])
                    .unsqueeze(0)
                    .unsqueeze(0),
                ):
                    output.append(x[1])
                    outputs_.append(cont_toks)
        # every cached result must correspond to exactly one original request
        assert len(output) == len(outputs_)
        # check indices: get_original must undo the sort/grouping
        reordered_output = loglikelihoods.get_original(output)
        assert reordered_output == [x[1] for x in loglikelihood_samples]
366
def test_aggregate_mean():
    """`weight_by_size` toggles between an unweighted and a size-weighted mean."""
    metrics = [0.3, 0.2, 0.4]
    sizes = [20, 40, 100]
    # unweighted: plain average of the subtask metrics
    assert aggregate_subtask_metrics(metrics, sizes, weight_by_size=False) == 0.3
    # weighted: (0.3*20 + 0.2*40 + 0.4*100) / 160
    assert aggregate_subtask_metrics(metrics, sizes, weight_by_size=True) == 0.3375
378
@pytest.mark.parametrize(
    "samples",
    [
        [40 * [1.0] + 60 * [0.0], 30 * [1.0] + 30 * [0.0], 20 * [1.0] + 60 * [0.0]],
        [35 * [1.0] + 65 * [0.0], 20 * [1.0] + 20 * [0.0]],
    ],
)
def test_aggregate_stderrs(samples):
    """Pooling per-subtask bootstrap stderrs with `pooled_sample_stderr`
    (size-weighted) should closely match the bootstrap stderr computed
    directly over the union of all samples.
    """
    mean_stderr = stderr_for_metric(metric=mean, bootstrap_iters=100000)

    subtask_stderrs = []
    subtask_sizes = []
    for subtask in samples:
        subtask_stderrs.append(mean_stderr(subtask))
        subtask_sizes.append(len(subtask))

    pooled = pooled_sample_stderr(subtask_stderrs, subtask_sizes)
    direct = mean_stderr(list(itertools.chain.from_iterable(samples)))
    assert np.allclose(pooled, direct, atol=1.0e-3)