1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
7
# http://www.apache.org/licenses/LICENSE-2.0
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
18
from predictor import NERPredictor
20
from paddlenlp.utils.log import logger
24
# Command-line parser; the options registered on it below configure this prediction script.
parser = argparse.ArgumentParser()
26
"--model_path_prefix", type=str, required=True, help="The path prefix of inference model to be used."
29
"--model_name_or_path", default="ernie-health-chinese", type=str, help="The directory or name of model."
31
# Dataset name; it is lower-cased later and used as the lookup key for labels and samples.
parser.add_argument(
    "--dataset",
    type=str,
    default="CMeEE",
    help="Dataset for named entity recognition.",
)
# Optional input file; left as None when the option is not supplied.
parser.add_argument(
    "--data_file",
    type=str,
    default=None,
    help="The data to predict with one sample per line.",
)
34
"--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization"
39
help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
41
# Prediction batch size (integer, 200 unless overridden).
parser.add_argument(
    "--batch_size",
    type=int,
    default=200,
    help="Batch size per GPU/CPU for predicting.",
)
43
"--num_threads", default=psutil.cpu_count(logical=False), type=int, help="Number of threads for cpu."
46
"--device", choices=["cpu", "gpu"], default="gpu", help="Select which device to train model, defaults to gpu."
48
# GPU index. `type=int` makes a value given on the command line an integer, matching the
# integer default; without it argparse would hand back the raw string (e.g. "1").
parser.add_argument("--device_id", default=0, type=int, help="Select which gpu device to train model.")
# Parse sys.argv once at import time; the rest of the script reads options from `args`.
args = parser.parse_args()
90
["B-sym", "I-sym", "E-sym", "S-sym", "O"],
94
# Built-in sample sentences per dataset (keyed by lower-case dataset name); used as the
# prediction input when no data file is supplied.
TEXT = {
    "cmeee": [
        "研究证实,细胞减少与肺内病变程度及肺内炎性病变吸收程度密切相关。",
        "可为不规则发热、稽留热或弛张热,但以不规则发热为多,可能与患儿应用退热药物导致热型不规律有关。",
    ],
}
100
# Log every parsed option, one "name: value" line each, with the name padded to 20 columns.
line_template = "{:20}: {}"
for option, setting in vars(args).items():
    logger.info(line_template.format(option, setting))

# The dataset name is lower-cased before it is used as a lookup key.
dataset = args.dataset.lower()
label_list = LABEL_LIST[dataset]
105
if args.data_file is not None:
106
# Decode as UTF-8 explicitly: the inputs are expected to be Chinese text like the built-in
# samples, and the platform default encoding is not UTF-8 everywhere.
with open(args.data_file, "r", encoding="utf-8") as fp:
    # One sample per line, surrounding whitespace removed; iterate the file directly
    # instead of materialising it with readlines() first.
    input_data = [line.strip() for line in fp]
109
input_data = TEXT[dataset]
111
# Build the predictor from the parsed options and the dataset's label list, then run it on
# the collected input sentences.
ner_predictor = NERPredictor(args, label_list)
ner_predictor.predict(input_data)
115
if __name__ == "__main__":