18
from predictor import SPOPredictor
20
from paddlenlp.utils.log import logger
24
# Command-line options; the parsed namespace is later handed to SPOPredictor.
parser = argparse.ArgumentParser()
26
# Required: path prefix of the inference model to load.
"--model_path_prefix", type=str, required=True, help="The path prefix of inference model to be used."
29
# Model name or local directory (default "ernie-health-chinese").
"--model_name_or_path", default="ernie-health-chinese", type=str, help="The directory or name of model."
31
# Dataset name. Lower-cased later before the label/example lookups, so "CMeIE" and
# "cmeie" are equivalent. The help text previously said "named entity recognition",
# but this script performs entity-relation (SPO) extraction.
parser.add_argument(
    "--dataset", default="CMeIE", type=str, help="Dataset for entity-relation (SPO) extraction."
)
# Optional input file with one sample per line. When omitted, the script uses the
# built-in example texts (TEXT) for the chosen dataset.
parser.add_argument("--data_file", default=None, type=str, help="The data to predict with one sample per line.")
34
# Maximum sequence length after tokenization (default 300).
"--max_seq_length", default=300, type=int, help="The maximum total input sequence length after tokenization."
39
# Help text of the fp16 switch; per the text it only takes effect on GPU.
help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
42
# CPU thread count, defaulting to the number of physical cores.
# NOTE(review): psutil.cpu_count(logical=False) may return None when the count cannot be
# determined; consider a fallback (e.g. `or 1`) before the value reaches the predictor.
"--num_threads", default=psutil.cpu_count(logical=False), type=int, help="num_threads for cpu."
44
# Prediction batch size (per GPU/CPU, according to the help text); defaults to 20.
parser.add_argument(
    "--batch_size",
    type=int,
    default=20,
    help="Batch size per GPU/CPU for predicting.",
)
46
# Inference device: "cpu" or "gpu" (default "gpu").
# NOTE(review): the help text says "train model", but this script only runs prediction; consider rewording.
"--device", choices=["cpu", "gpu"], default="gpu", help="Select which device to train model, defaults to gpu."
48
# GPU card index. Parsed as int so a value given on the command line has the same type
# as the default (0); without `type=int`, argparse would deliver it as a str.
parser.add_argument("--device_id", default=0, type=int, help="Select which gpu device to run inference on.")
# Parse the command line into the namespace used by the rest of the script.
args = parser.parse_args()
102
# Built-in example inputs, keyed by lower-cased dataset name; used when no data file is given.
TEXT = {
    "cmeie": [
        "骶髂关节炎是明确诊断JAS的关键条件。若有肋椎关节病变会使胸部扩张度减小。",
        "稳定型缺血性心脏疾病@肥胖与缺乏活动也导致高血压增多。",
    ],
}
108
# Record every parsed option and its value, one log line each, name padded to 20 columns.
for opt_name, opt_value in vars(args).items():
    line = "{:20}: {}".format(opt_name, opt_value)
    logger.info(line)
111
# Lower-case the dataset name so lookups match the table keys (e.g. "CMeIE" -> "cmeie").
dataset = args.dataset.lower()
label_list = LABEL_LIST[dataset]
if args.data_file is not None:
    # One sample per line. Decode as UTF-8 explicitly: the inputs are Chinese text, and
    # open()'s default encoding is locale-dependent, so it can mis-decode or raise
    # UnicodeDecodeError on non-UTF-8 locales.
    with open(args.data_file, "r", encoding="utf-8") as fp:
        input_data = [line.strip() for line in fp]
else:
    # No input file: fall back to the built-in example texts. The explicit `else` keeps
    # this from overwriting data that was just read from the file.
    input_data = TEXT[dataset]
119
# Build the SPO predictor from the parsed options and the dataset's label list,
# then run prediction over the collected input texts.
predictor = SPOPredictor(args, label_list)
120
predictor.predict(input_data)
123
if __name__ == "__main__":