1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
from __future__ import annotations
26
from paddlenlp.transformers import LlamaTokenizer
29
def is_port_in_use(port):
30
with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
33
s.bind(("localhost", port))
39
class UITest(unittest.TestCase):
42
self.flask_port = self.avaliable_free_port()
43
self.port = self.avaliable_free_port([self.flask_port])
44
self.model_path = "__internal_testing__/micro-random-llama"
45
command = 'cd llm && python flask_server.py --model_name_or_path {model_path} --port {port} --flask_port {flask_port} --src_length 1024 --dtype "float16"'.format(
46
flask_port=self.flask_port, port=self.port, model_path=self.model_path
48
self.ui_process = subprocess.Popen(command, shell=True, stdout=sys.stdout, stderr=sys.stderr)
49
self.tokenizer = LlamaTokenizer.from_pretrained(self.model_path)
51
return super().setUp()
54
self.ui_process.kill()
56
def avaliable_free_port(self, exclude=None):
57
exclude = exclude or []
58
for port in range(8000, 10000):
61
if is_port_in_use(port):
65
raise ValueError("can not get valiable port in [8000, 8200]")
67
def wait_until_server_is_ready(self):
69
if is_port_in_use(self.flask_port) and is_port_in_use(self.port):
72
print("waiting for server ...")
75
def get_gradio_ui_result(self, *args, **kwargs):
76
_, _, file = self.client.predict(*args, **kwargs)
78
with open(file, "r", encoding="utf-8") as f:
79
content = json.load(f)
80
return content[-1]["utterance"]
82
@pytest.mark.timeout(4 * 60)
83
def test_argument(self):
84
self.wait_until_server_is_ready()
86
def get_response(data):
87
res = requests.post(f"http://localhost:{self.flask_port}/api/chat", json=data, stream=True)
89
for line in res.iter_lines():
91
result = json.loads(line)
92
bot_response = result["result"]["response"]
94
if bot_response["utterance"].endswith("[END]"):
95
bot_response["utterance"] = bot_response["utterance"][:-5]
97
result_ += bot_response["utterance"]
106
"repetition_penalty": 1.0,
110
# Case 1: greedy search
111
# result_0 = get_response(data)
112
result_1 = get_response(data)
114
# TODO(wj-Mcat): enable logit-comparision later
115
# assert result_0 == result_1
122
"repetition_penalty": 1.0,
128
result_2 = get_response(data)
129
# assert result_1 != result_2
132
assert 10 <= len(self.tokenizer.tokenize(result_1)) <= 50
133
assert 10 <= len(self.tokenizer.tokenize(result_2)) <= 50
140
"repetition_penalty": 1.0,
145
result_3 = get_response(data)
146
assert result_3 != result_2
147
assert 70 <= len(self.tokenizer.tokenize(result_3)) <= 150