paddlenlp

Форк
0
146 строк · 5.4 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16

17
import psutil
18
from predictor import CLSPredictor
19

20
from paddlenlp.utils.log import logger
21

22

23
def parse_args():
24
    parser = argparse.ArgumentParser()
25
    parser.add_argument(
26
        "--model_path_prefix", type=str, required=True, help="The path prefix of inference model to be used."
27
    )
28
    parser.add_argument(
29
        "--model_name_or_path", default="ernie-health-chinese", type=str, help="The directory or name of model."
30
    )
31
    parser.add_argument("--dataset", default="KUAKE-QIC", type=str, help="Dataset for text classfication.")
32
    parser.add_argument("--data_file", default=None, type=str, help="The data to predict with one sample per line.")
33
    parser.add_argument(
34
        "--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization."
35
    )
36
    parser.add_argument(
37
        "--use_fp16",
38
        action="store_true",
39
        help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
40
    )
41
    parser.add_argument("--batch_size", default=200, type=int, help="Batch size per GPU/CPU for predicting.")
42
    parser.add_argument(
43
        "--num_threads", default=psutil.cpu_count(logical=False), type=int, help="num_threads for cpu."
44
    )
45
    parser.add_argument(
46
        "--device", choices=["cpu", "gpu"], default="gpu", help="Select which device to train model, defaults to gpu."
47
    )
48
    parser.add_argument("--device_id", default=0, help="Select which gpu device to train model.")
49
    args = parser.parse_args()
50
    return args
51

52

53
LABEL_LIST = {
54
    "kuake-qic": ["病情诊断", "治疗方案", "病因分析", "指标解读", "就医建议", "疾病表述", "后果表述", "注意事项", "功效作用", "医疗费用", "其他"],
55
    "kuake-qtr": ["完全不匹配", "很少匹配,有一些参考价值", "部分匹配", "完全匹配"],
56
    "kuake-qqr": ["B为A的语义父集,B指代范围大于A; 或者A与B语义毫无关联。", "B为A的语义子集,B指代范围小于A。", "表示A与B等价,表述完全一致。"],
57
    "chip-ctc": [
58
        "成瘾行为",
59
        "居住情况",
60
        "年龄",
61
        "酒精使用",
62
        "过敏耐受",
63
        "睡眠",
64
        "献血",
65
        "能力",
66
        "依存性",
67
        "知情同意",
68
        "数据可及性",
69
        "设备",
70
        "诊断",
71
        "饮食",
72
        "残疾群体",
73
        "疾病",
74
        "教育情况",
75
        "病例来源",
76
        "参与其它试验",
77
        "伦理审查",
78
        "种族",
79
        "锻炼",
80
        "性别",
81
        "健康群体",
82
        "实验室检查",
83
        "预期寿命",
84
        "读写能力",
85
        "含有多类别的语句",
86
        "肿瘤进展",
87
        "疾病分期",
88
        "护理",
89
        "口腔相关",
90
        "器官组织状态",
91
        "药物",
92
        "怀孕相关",
93
        "受体状态",
94
        "研究者决定",
95
        "风险评估",
96
        "性取向",
97
        "体征(医生检测)",
98
        " 吸烟状况",
99
        "特殊病人特征",
100
        "症状(患者感受)",
101
        "治疗或手术",
102
    ],
103
    "chip-sts": ["语义不同", "语义相同"],
104
    "chip-cdn-2c": ["否", "是"],
105
}
106

107
TEXT = {
108
    "kuake-qic": ["心肌缺血如何治疗与调养呢?", "什么叫痔核脱出?什么叫外痔?"],
109
    "kuake-qtr": [["儿童远视眼怎么恢复视力", "远视眼该如何保养才能恢复一些视力"], ["抗生素的药有哪些", "抗生素类的药物都有哪些?"]],
110
    "kuake-qqr": [["茴香是发物吗", "茴香怎么吃?"], ["气的胃疼是怎么回事", "气到胃痛是什么原因"]],
111
    "chip-ctc": ["(1)前牙结构发育不良:釉质发育不全、氟斑牙、四环素牙等;", "怀疑或确有酒精或药物滥用史;"],
112
    "chip-sts": [["糖尿病能吃减肥药吗?能治愈吗?", "糖尿病为什么不能吃减肥药"], ["H型高血压的定义", "WHO对高血压的最新分类定义标准数值"]],
113
    "chip-cdn-2c": [["1型糖尿病性植物神经病变", " 1型糖尿病肾病IV期"], ["髂腰肌囊性占位", "髂肌囊肿"]],
114
}
115

116
METRIC = {
117
    "kuake-qic": "acc",
118
    "kuake-qtr": "acc",
119
    "kuake-qqr": "acc",
120
    "chip-ctc": "macro",
121
    "chip-sts": "macro",
122
    "chip-cdn-2c": "macro",
123
}
124

125

126
def main():
127
    args = parse_args()
128

129
    for arg_name, arg_value in vars(args).items():
130
        logger.info("{:20}: {}".format(arg_name, arg_value))
131

132
    args.dataset = args.dataset.lower()
133
    label_list = LABEL_LIST[args.dataset]
134
    if args.data_file is not None:
135
        with open(args.data_file, "r") as fp:
136
            input_data = [x.strip().split("\t") for x in fp.readlines()]
137
            input_data = [x[0] if len(x) == 1 else x for x in input_data]
138
    else:
139
        input_data = TEXT[args.dataset]
140

141
    predictor = CLSPredictor(args, label_list)
142
    predictor.predict(input_data)
143

144

145
if __name__ == "__main__":
146
    main()
147

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.