paddlenlp

Форк
0
/
test_gradio.py 
147 строк · 4.6 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import json
import socket
import subprocess
import sys
import time
import unittest

import pytest
import requests

from paddlenlp.transformers import LlamaTokenizer

def is_port_in_use(port):
    """Return True if *port* on localhost cannot be bound (i.e. it is taken)."""
    probe = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    probe.settimeout(0.01)
    try:
        probe.bind(("localhost", port))
    except socket.error:
        # bind() raising (typically EADDRINUSE) means the port is occupied.
        return True
    else:
        return False
    finally:
        probe.close()


class UITest(unittest.TestCase):
    """End-to-end test for the llm flask chat server.

    ``setUp`` launches ``llm/flask_server.py`` in a subprocess on two freshly
    chosen free ports (flask API + serving port) and loads the matching
    tokenizer; ``tearDown`` kills the subprocess.
    """

    def setUp(self):
        # start web ui: reserve two distinct free ports, one for the flask
        # API and one for the serving port.
        self.flask_port = self.avaliable_free_port()
        self.port = self.avaliable_free_port([self.flask_port])
        self.model_path = "__internal_testing__/micro-random-llama"
        command = (
            f"cd llm && python flask_server.py --model_name_or_path {self.model_path} "
            f'--port {self.port} --flask_port {self.flask_port} --src_length 1024 --dtype "float16"'
        )
        # NOTE(review): with shell=True, ui_process is the shell wrapper, so
        # tearDown's kill() may leave the python child running — confirm.
        self.ui_process = subprocess.Popen(command, shell=True, stdout=sys.stdout, stderr=sys.stderr)
        self.tokenizer = LlamaTokenizer.from_pretrained(self.model_path)

        return super().setUp()

    def tearDown(self):
        # Stop the server subprocess started in setUp.
        self.ui_process.kill()

    def avaliable_free_port(self, exclude=None):
        """Return the first bindable localhost port in [8000, 10000).

        Args:
            exclude: optional collection of ports to skip (e.g. ports this
                test has already reserved).

        Raises:
            ValueError: if no free port exists in the scanned range.

        Note: the misspelled method name is kept for backward compatibility.
        """
        exclude = exclude or []
        for port in range(8000, 10000):
            if port in exclude:
                continue
            if is_port_in_use(port):
                continue
            return port

        # Fix: the previous message reported the wrong range ("[8000, 8200]")
        # and misspelled "available"; the loop actually scans [8000, 10000).
        raise ValueError("can not get available port in [8000, 10000)")

    def wait_until_server_is_ready(self):
        # Poll until both ports are occupied, which signals the server is up.
        # No timeout here: the @pytest.mark.timeout on the test method bounds
        # the overall wait.
        while True:
            if is_port_in_use(self.flask_port) and is_port_in_use(self.port):
                break

            print("waiting for server ...")
            time.sleep(1)

    def get_gradio_ui_result(self, *args, **kwargs):
        """Run a gradio prediction and return the last utterance of the chat log."""
        # NOTE(review): self.client is never assigned anywhere in this class —
        # presumably a gradio Client created elsewhere; verify before use.
        _, _, file = self.client.predict(*args, **kwargs)

        with open(file, "r", encoding="utf-8") as f:
            content = json.load(f)
        return content[-1]["utterance"]

    @pytest.mark.timeout(4 * 60)
    def test_argument(self):
        """Exercise /api/chat with greedy, sampling, and long-max_length configs."""
        self.wait_until_server_is_ready()

        def get_response(data):
            # Stream the response line by line and concatenate the partial
            # utterances, stripping the trailing "[END]" marker when present.
            res = requests.post(f"http://localhost:{self.flask_port}/api/chat", json=data, stream=True)
            result_ = ""
            for line in res.iter_lines():
                print(line)
                result = json.loads(line)
                bot_response = result["result"]["response"]

                if bot_response["utterance"].endswith("[END]"):
                    bot_response["utterance"] = bot_response["utterance"][:-5]

                result_ += bot_response["utterance"]

            return result_

        data = {
            "context": "你好",
            "top_k": 1,
            "top_p": 1.0,
            "temperature": 1.0,
            "repetition_penalty": 1.0,
            "max_length": 20,
            "min_length": 1,
        }
        # Case 1: greedy search
        # result_0 = get_response(data)
        result_1 = get_response(data)

        # TODO(wj-Mcat): enable logit-comparision later
        # assert result_0 == result_1

        data = {
            "context": "你好",
            "top_k": 0,
            "top_p": 0.7,
            "temperature": 1.0,
            "repetition_penalty": 1.0,
            "max_length": 20,
            "min_length": 1,
        }

        # Case 2: sampling
        result_2 = get_response(data)
        # assert result_1 != result_2

        # The generated token counts should stay within the expected bounds.
        assert 10 <= len(self.tokenizer.tokenize(result_1)) <= 50
        assert 10 <= len(self.tokenizer.tokenize(result_2)) <= 50

        data = {
            "context": "你好",
            "top_k": 1,
            "top_p": 0.7,
            "temperature": 1.0,
            "repetition_penalty": 1.0,
            "max_length": 100,
            "min_length": 1,
        }
        # Case 3: max_length
        result_3 = get_response(data)
        assert result_3 != result_2
        assert 70 <= len(self.tokenizer.tokenize(result_3)) <= 150

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.