lm-evaluation-harness

Форк
0
152 строки · 5.3 Кб
1
import hashlib
2
import json
3
import os
4
import pickle
5
import unittest
6
from unittest.mock import patch
7

8
from lm_eval.api.instance import Instance
9
from lm_eval.models.gguf import GGUFLM
10

11

12
base_url = "https://matthoffner-ggml-llm-api.hf.space"
13

14

15
def gguf_completion_mock(base_url=None, **kwargs):
16
    # Generate a hash from the parameters
17
    hash_kwargs = {"base_url": base_url, **kwargs}
18
    hash = hashlib.sha256(
19
        json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
20
    ).hexdigest()
21

22
    fname = f"./tests/testdata/gguf_test_{hash}.pkl"
23

24
    if os.path.exists(fname):
25
        with open(fname, "rb") as fh:
26
            return pickle.load(fh)
27
    else:
28
        print("The file does not exist, attempting to write...")
29
        if "stop" in kwargs:
30
            result = {
31
                "choices": [
32
                    {
33
                        "text": f"generated text until {kwargs['stop']}",
34
                        "logprobs": {"token_logprobs": [-1.2345], "text_offset": 0},
35
                        "finish_reason": "length",
36
                    }
37
                ]
38
            }
39
        else:
40
            # generated with # curl -X 'POST'   'http://localhost:8000/v1/completions'   -H 'accept: application/json'   -H 'Content-Type: application/json'   -d '{"prompt": "string", "logprobs": 10, "temperature": 0.0, "max_tokens": 1, "echo": true}'
41
            result = {
42
                "id": "cmpl-4023976b-bc6a-43b0-a5a9-629f4216c7f3",
43
                "object": "text_completion",
44
                "created": 1700511361,
45
                "model": "../llama-2-7b.Q8_0.gguf",
46
                "choices": [
47
                    {
48
                        "text": "string(",
49
                        "index": 0,
50
                        "logprobs": {
51
                            "text_offset": [0, 7],
52
                            "token_logprobs": [None, -1.033263319857306],
53
                            "tokens": [" string", "("],
54
                            "top_logprobs": [
55
                                None,
56
                                {
57
                                    "(": -1.033263319857306,
58
                                    "[]": -2.6530743779017394,
59
                                    ".": -3.0377145947291324,
60
                                    "\n": -3.0399156750513976,
61
                                    "_": -3.510376089937872,
62
                                    " =": -3.6957918347193663,
63
                                    ",": -3.9309459866358702,
64
                                    " of": -4.2834550083949035,
65
                                    '("': -4.322762841112799,
66
                                    "()": -4.426229113466925,
67
                                },
68
                            ],
69
                        },
70
                        "finish_reason": "length",
71
                    }
72
                ],
73
                "usage": {
74
                    "prompt_tokens": 2,
75
                    "completion_tokens": 1,
76
                    "total_tokens": 3,
77
                },
78
            }
79

80
        try:
81
            os.makedirs(os.path.dirname(fname), exist_ok=True)
82
            print("Writing file at", fname)
83
            with open(fname, "wb") as fh:
84
                pickle.dump(result, fh)
85
            print("File written successfully")
86
        except Exception as e:
87
            print("File writing failed:", e)
88

89
        return result
90

91

92
class GGUFLMTest(unittest.TestCase):
93
    @patch(
94
        "lm_eval.models.gguf.GGUFLM.gguf_completion", side_effect=gguf_completion_mock
95
    )
96
    def test_loglikelihood(self, gguf_completion_mock):
97
        lm = GGUFLM(base_url)
98

99
        # Test loglikelihood
100
        requests = [
101
            Instance(
102
                request_type="loglikelihood",
103
                doc=args,
104
                arguments=args,
105
                idx=i,
106
            )
107
            for i, args in enumerate([("str", "ing"), ("str", "ing")])
108
        ]
109
        res = lm.loglikelihood(requests)
110

111
        # Assert the loglikelihood response is correct
112
        expected_res = [(logprob, True) for logprob in [0, 0]]
113
        self.assertEqual(res, expected_res)
114

115
    @patch(
116
        "lm_eval.models.gguf.GGUFLM.gguf_completion", side_effect=gguf_completion_mock
117
    )
118
    def test_generate_until(self, gguf_completion_mock):
119
        lm = GGUFLM(base_url)
120

121
        # Test generate_until
122
        requests = [
123
            Instance(
124
                request_type="generate_until",
125
                doc={"input": doc},
126
                arguments=(doc, {"until": stop}),
127
                idx=i,
128
            )
129
            for i, (doc, stop) in enumerate([("input1", "stop1"), ("input2", "stop2")])
130
        ]
131

132
        res = lm.generate_until(requests)
133

134
        # Assert the generate_until response is correct
135
        expected_res = ["generated text until stop1", "generated text until stop2"]
136
        self.assertEqual(res, expected_res)
137

138
    # @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
139
    # def test_loglikelihood_rolling(self, gguf_completion_mock):
140
    #     lm = GGUFLM(base_url)
141

142
    #     # Test loglikelihood_rolling
143
    #     requests = ["input1", "input2"]
144
    #     res = lm.loglikelihood_rolling(requests)
145

146
    #     # Assert the loglikelihood_rolling response is correct
147
    #     expected_res = [(-1.2345, True), (-1.2345, True)]
148
    #     self.assertEqual(res, expected_res)
149

150

151
if __name__ == "__main__":
152
    unittest.main()
153

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.