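# Tests for the GGUFLM model wrapper. Every call to GGUFLM.gguf_completion is
# patched with gguf_completion_mock below, which serves canned responses
# (optionally cached as pickled fixtures under ./tests/testdata/), so no live
# GGUF server is needed.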
import hashlib
import json
import os
import pickle
import unittest
from unittest.mock import patch

from lm_eval.api.instance import Instance
from lm_eval.models.gguf import GGUFLM


base_url = "https://matthoffner-ggml-llm-api.hf.space"


def gguf_completion_mock(base_url=None, **kwargs):
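    # Return a cached response for these exact arguments if a pickled fixture
    # exists; otherwise build a canned completion payload and try to persist it
    # so later runs are deterministic and offline.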
    # Generate a hash from the parameters
    hash_kwargs = {"base_url": base_url, **kwargs}
    hash = hashlib.sha256(
        json.dumps(hash_kwargs, sort_keys=True).encode("utf-8")
    ).hexdigest()

    fname = f"./tests/testdata/gguf_test_{hash}.pkl"

    if os.path.exists(fname):
        with open(fname, "rb") as fh:
            return pickle.load(fh)
    else:
        print("The file does not exist, attempting to write...")
        if "stop" in kwargs:
            result = {
                "choices": [
                    {
                        "text": f"generated text until {kwargs['stop']}",
                        "logprobs": {"token_logprobs": [-1.2345], "text_offset": 0},
                        "finish_reason": "length",
                    }
                ]
            }
        else:
            # generated with # curl -X 'POST' 'http://localhost:8000/v1/completions' -H 'accept: application/json' -H 'Content-Type: application/json' -d '{"prompt": "string", "logprobs": 10, "temperature": 0.0, "max_tokens": 1, "echo": true}'
            result = {
                "id": "cmpl-4023976b-bc6a-43b0-a5a9-629f4216c7f3",
                "object": "text_completion",
                "created": 1700511361,
                "model": "../llama-2-7b.Q8_0.gguf",
                "choices": [
                    {
                        "text": "string(",
                        "index": 0,
                        "logprobs": {
                            "text_offset": [0, 7],
                            "token_logprobs": [None, -1.033263319857306],
                            "tokens": [" string", "("],
                            "top_logprobs": [
                                None,
                                {
                                    "(": -1.033263319857306,
                                    "[]": -2.6530743779017394,
                                    ".": -3.0377145947291324,
                                    "\n": -3.0399156750513976,
                                    "_": -3.510376089937872,
                                    " =": -3.6957918347193663,
                                    ",": -3.9309459866358702,
                                    " of": -4.2834550083949035,
                                    '("': -4.322762841112799,
                                    "()": -4.426229113466925,
                                },
                            ],
                        },
                        "finish_reason": "length",
                    }
                ],
                "usage": {
                    "prompt_tokens": 2,
                    "completion_tokens": 1,
                    "total_tokens": 3,
                },
            }

        try:
            os.makedirs(os.path.dirname(fname), exist_ok=True)
            print("Writing file at", fname)
            with open(fname, "wb") as fh:
                pickle.dump(result, fh)
            print("File written successfully")
        except Exception as e:
            print("File writing failed:", e)

    return result


class GGUFLMTest(unittest.TestCase):
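    # Each test patches GGUFLM.gguf_completion with the mock above, so request
    # handling is exercised without contacting base_url.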
    @patch(
        "lm_eval.models.gguf.GGUFLM.gguf_completion", side_effect=gguf_completion_mock
    )
    def test_loglikelihood(self, gguf_completion_mock):
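        # Each Instance carries a (context, continuation) pair; the patched
        # gguf_completion returns the same canned payload for every request,
        # so the expected scores are fixed.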
        lm = GGUFLM(base_url)

        # Test loglikelihood
        requests = [
            Instance(
                request_type="loglikelihood",
                doc=args,
                arguments=args,
                idx=i,
            )
            for i, args in enumerate([("str", "ing"), ("str", "ing")])
        ]
        res = lm.loglikelihood(requests)

        # Assert the loglikelihood response is correct
        expected_res = [(logprob, True) for logprob in [0, 0]]
        self.assertEqual(res, expected_res)

    @patch(
        "lm_eval.models.gguf.GGUFLM.gguf_completion", side_effect=gguf_completion_mock
    )
    def test_generate_until(self, gguf_completion_mock):
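        # Each Instance pairs an input string with an "until" stop sequence;
        # the mock echoes "generated text until <stop>" for such requests.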
        lm = GGUFLM(base_url)

        # Test generate_until
        requests = [
            Instance(
                request_type="generate_until",
                doc={"input": doc},
                arguments=(doc, {"until": stop}),
                idx=i,
            )
            for i, (doc, stop) in enumerate([("input1", "stop1"), ("input2", "stop2")])
        ]

        res = lm.generate_until(requests)

        # Assert the generate_until response is correct
        expected_res = ["generated text until stop1", "generated text until stop2"]
        self.assertEqual(res, expected_res)

    # @patch('lm_eval.models.gguf.GGUFLM.gguf_completion', side_effect=gguf_completion_mock)
    # def test_loglikelihood_rolling(self, gguf_completion_mock):
    #     lm = GGUFLM(base_url)

    #     # Test loglikelihood_rolling
    #     requests = ["input1", "input2"]
    #     res = lm.loglikelihood_rolling(requests)

    #     # Assert the loglikelihood_rolling response is correct
    #     expected_res = [(-1.2345, True), (-1.2345, True)]
    #     self.assertEqual(res, expected_res)


if __name__ == "__main__":
    unittest.main()