paddlenlp

Форк
0
124 строки · 3.9 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16

17
import psutil
18
from predictor import SPOPredictor
19

20
from paddlenlp.utils.log import logger
21

22

23
def parse_args():
24
    parser = argparse.ArgumentParser()
25
    parser.add_argument(
26
        "--model_path_prefix", type=str, required=True, help="The path prefix of inference model to be used."
27
    )
28
    parser.add_argument(
29
        "--model_name_or_path", default="ernie-health-chinese", type=str, help="The directory or name of model."
30
    )
31
    parser.add_argument("--dataset", default="CMeIE", type=str, help="Dataset for named entity recognition.")
32
    parser.add_argument("--data_file", default=None, type=str, help="The data to predict with one sample per line.")
33
    parser.add_argument(
34
        "--max_seq_length", default=300, type=int, help="The maximum total input sequence length after tokenization."
35
    )
36
    parser.add_argument(
37
        "--use_fp16",
38
        action="store_true",
39
        help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
40
    )
41
    parser.add_argument(
42
        "--num_threads", default=psutil.cpu_count(logical=False), type=int, help="num_threads for cpu."
43
    )
44
    parser.add_argument("--batch_size", default=20, type=int, help="Batch size per GPU/CPU for predicting.")
45
    parser.add_argument(
46
        "--device", choices=["cpu", "gpu"], default="gpu", help="Select which device to train model, defaults to gpu."
47
    )
48
    parser.add_argument("--device_id", default=0, help="Select which gpu device to train model.")
49
    args = parser.parse_args()
50
    return args
51

52

53
LABEL_LIST = {
54
    "cmeie": [
55
        "预防",
56
        "阶段",
57
        "就诊科室",
58
        "辅助治疗",
59
        "化疗",
60
        "放射治疗",
61
        "手术治疗",
62
        "实验室检查",
63
        "影像学检查",
64
        "辅助检查",
65
        "组织学检查",
66
        "内窥镜检查",
67
        "筛查",
68
        "多发群体",
69
        "发病率",
70
        "发病年龄",
71
        "多发地区",
72
        "发病性别倾向",
73
        "死亡率",
74
        "多发季节",
75
        "传播途径",
76
        "并发症",
77
        "病理分型",
78
        "相关(导致)",
79
        "鉴别诊断",
80
        "相关(转化)",
81
        "相关(症状)",
82
        "临床表现",
83
        "治疗后症状",
84
        "侵及周围组织转移的症状",
85
        "病因",
86
        "高危因素",
87
        "风险评估因素",
88
        "病史",
89
        "遗传因素",
90
        "发病机制",
91
        "病理生理",
92
        "药物治疗",
93
        "发病部位",
94
        "转移部位",
95
        "外侵部位",
96
        "预后状况",
97
        "预后生存率",
98
        "同义词",
99
    ]
100
}
101

102
TEXT = {"cmeie": ["骶髂关节炎是明确诊断JAS的关键条件。若有肋椎关节病变会使胸部扩张度减小。", "稳定型缺血性心脏疾病@肥胖与缺乏活动也导致高血压增多。"]}
103

104

105
def main():
106
    args = parse_args()
107

108
    for arg_name, arg_value in vars(args).items():
109
        logger.info("{:20}: {}".format(arg_name, arg_value))
110

111
    dataset = args.dataset.lower()
112
    label_list = LABEL_LIST[dataset]
113
    if args.data_file is not None:
114
        with open(args.data_file, "r") as fp:
115
            input_data = [x.strip() for x in fp.readlines()]
116
    else:
117
        input_data = TEXT[dataset]
118

119
    predictor = SPOPredictor(args, label_list)
120
    predictor.predict(input_data)
121

122

123
if __name__ == "__main__":
124
    main()
125

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.