paddlenlp

Форк
0
116 строк · 3.7 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16

17
import psutil
18
from predictor import NERPredictor
19

20
from paddlenlp.utils.log import logger
21

22

23
def parse_args():
24
    parser = argparse.ArgumentParser()
25
    parser.add_argument(
26
        "--model_path_prefix", type=str, required=True, help="The path prefix of inference model to be used."
27
    )
28
    parser.add_argument(
29
        "--model_name_or_path", default="ernie-health-chinese", type=str, help="The directory or name of model."
30
    )
31
    parser.add_argument("--dataset", default="CMeEE", type=str, help="Dataset for named entity recognition.")
32
    parser.add_argument("--data_file", default=None, type=str, help="The data to predict with one sample per line.")
33
    parser.add_argument(
34
        "--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization"
35
    )
36
    parser.add_argument(
37
        "--use_fp16",
38
        action="store_true",
39
        help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
40
    )
41
    parser.add_argument("--batch_size", default=200, type=int, help="Batch size per GPU/CPU for predicting.")
42
    parser.add_argument(
43
        "--num_threads", default=psutil.cpu_count(logical=False), type=int, help="Number of threads for cpu."
44
    )
45
    parser.add_argument(
46
        "--device", choices=["cpu", "gpu"], default="gpu", help="Select which device to train model, defaults to gpu."
47
    )
48
    parser.add_argument("--device_id", default=0, help="Select which gpu device to train model.")
49
    args = parser.parse_args()
50
    return args
51

52

53
LABEL_LIST = {
54
    "cmeee": [
55
        [
56
            "B-bod",
57
            "I-bod",
58
            "E-bod",
59
            "S-bod",
60
            "B-dis",
61
            "I-dis",
62
            "E-dis",
63
            "S-dis",
64
            "B-pro",
65
            "I-pro",
66
            "E-pro",
67
            "S-pro",
68
            "B-dru",
69
            "I-dru",
70
            "E-dru",
71
            "S-dru",
72
            "B-ite",
73
            "I-ite",
74
            "E-ite",
75
            "S-ite",
76
            "B-mic",
77
            "I-mic",
78
            "E-mic",
79
            "S-mic",
80
            "B-equ",
81
            "I-equ",
82
            "E-equ",
83
            "S-equ",
84
            "B-dep",
85
            "I-dep",
86
            "E-dep",
87
            "S-dep",
88
            "O",
89
        ],
90
        ["B-sym", "I-sym", "E-sym", "S-sym", "O"],
91
    ]
92
}
93

94
TEXT = {"cmeee": ["研究证实,细胞减少与肺内病变程度及肺内炎性病变吸收程度密切相关。", "可为不规则发热、稽留热或弛张热,但以不规则发热为多,可能与患儿应用退热药物导致热型不规律有关。"]}
95

96

97
def main():
98
    args = parse_args()
99

100
    for arg_name, arg_value in vars(args).items():
101
        logger.info("{:20}: {}".format(arg_name, arg_value))
102

103
    dataset = args.dataset.lower()
104
    label_list = LABEL_LIST[dataset]
105
    if args.data_file is not None:
106
        with open(args.data_file, "r") as fp:
107
            input_data = [x.strip() for x in fp.readlines()]
108
    else:
109
        input_data = TEXT[dataset]
110

111
    predictor = NERPredictor(args, label_list)
112
    predictor.predict(input_data)
113

114

115
if __name__ == "__main__":
116
    main()
117

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.