paddlenlp

Форк
0
54 строки · 2.5 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16
import json
17

18
import requests
19

20
parser = argparse.ArgumentParser()
21
parser.add_argument("--dataset", required=True, type=str, help="The dataset name for the simple seving")
22
parser.add_argument(
23
    "--max_seq_len", default=128, type=int, help="The maximum total input sequence length after tokenization."
24
)
25
parser.add_argument("--batch_size", default=1, type=int, help="Batch size per GPU/CPU for predicting.")
26
args = parser.parse_args()
27

28
url = "http://0.0.0.0:8189/models/cblue_cls"
29
headers = {"Content-Type": "application/json"}
30

31
TEXT = {
32
    "kuake-qic": ["心肌缺血如何治疗与调养呢?", "什么叫痔核脱出?什么叫外痔?"],
33
    "kuake-qtr": [["儿童远视眼怎么恢复视力", "远视眼该如何保养才能恢复一些视力"], ["抗生素的药有哪些", "抗生素类的药物都有哪些?"]],
34
    "kuake-qqr": [["茴香是发物吗", "茴香怎么吃?"], ["气的胃疼是怎么回事", "气到胃痛是什么原因"]],
35
    "chip-ctc": ["(1)前牙结构发育不良:釉质发育不全、氟斑牙、四环素牙等;", "怀疑或确有酒精或药物滥用史;"],
36
    "chip-sts": [["糖尿病能吃减肥药吗?能治愈吗?", "糖尿病为什么不能吃减肥药"], ["H型高血压的定义", "WHO对高血压的最新分类定义标准数值"]],
37
    "chip-cdn-2c": [["1型糖尿病性植物神经病变", " 1型糖尿病肾病IV期"], ["髂腰肌囊性占位", "髂肌囊肿"]],
38
}
39

40
if __name__ == "__main__":
41
    args.dataset = args.dataset.lower()
42
    input_data = TEXT[args.dataset]
43
    texts = []
44
    text_pairs = []
45
    for data in input_data:
46
        if len(data) == 2:
47
            text_pairs.append(data[1])
48
        texts.append(data[0])
49
    data = {
50
        "data": {"text": texts, "text_pair": text_pairs if len(text_pairs) > 0 else None},
51
        "parameters": {"max_seq_len": args.max_seq_len, "batch_size": args.batch_size},
52
    }
53
    r = requests.post(url=url, headers=headers, data=json.dumps(data))
54
    print(r.text)
55

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.