18
from predictor import CLSPredictor
20
from paddlenlp.utils.log import logger
24
parser = argparse.ArgumentParser()
26
"--model_path_prefix", type=str, required=True, help="The path prefix of inference model to be used."
29
"--model_name_or_path", default="ernie-health-chinese", type=str, help="The directory or name of model."
31
# Which task to run and, optionally, a file holding the samples to classify.
parser.add_argument("--dataset", default="KUAKE-QIC", type=str, help="Dataset for text classification.")
parser.add_argument("--data_file", default=None, type=str, help="The data to predict with one sample per line.")
34
"--max_seq_length", default=128, type=int, help="The maximum total input sequence length after tokenization."
39
help="Whether to use fp16 inference, only takes effect when deploying on gpu.",
41
# Number of samples sent through the predictor at once (200 unless overridden).
parser.add_argument(
    "--batch_size",
    type=int,
    default=200,
    help="Batch size per GPU/CPU for predicting.",
)
43
"--num_threads", default=psutil.cpu_count(logical=False), type=int, help="num_threads for cpu."
46
"--device", choices=["cpu", "gpu"], default="gpu", help="Select which device to train model, defaults to gpu."
48
# Convert the GPU index to int: without an explicit type argparse returns a value given
# on the command line as a string, while the default (0) is an integer.
parser.add_argument("--device_id", default=0, type=int, help="Select which gpu device to train model.")
args = parser.parse_args()
54
"kuake-qic": ["病情诊断", "治疗方案", "病因分析", "指标解读", "就医建议", "疾病表述", "后果表述", "注意事项", "功效作用", "医疗费用", "其他"],
55
"kuake-qtr": ["完全不匹配", "很少匹配,有一些参考价值", "部分匹配", "完全匹配"],
56
"kuake-qqr": ["B为A的语义父集,B指代范围大于A; 或者A与B语义毫无关联。", "B为A的语义子集,B指代范围小于A。", "表示A与B等价,表述完全一致。"],
103
"chip-sts": ["语义不同", "语义相同"],
104
"chip-cdn-2c": ["否", "是"],
108
"kuake-qic": ["心肌缺血如何治疗与调养呢?", "什么叫痔核脱出?什么叫外痔?"],
109
"kuake-qtr": [["儿童远视眼怎么恢复视力", "远视眼该如何保养才能恢复一些视力"], ["抗生素的药有哪些", "抗生素类的药物都有哪些?"]],
110
"kuake-qqr": [["茴香是发物吗", "茴香怎么吃?"], ["气的胃疼是怎么回事", "气到胃痛是什么原因"]],
111
"chip-ctc": ["(1)前牙结构发育不良:釉质发育不全、氟斑牙、四环素牙等;", "怀疑或确有酒精或药物滥用史;"],
112
"chip-sts": [["糖尿病能吃减肥药吗?能治愈吗?", "糖尿病为什么不能吃减肥药"], ["H型高血压的定义", "WHO对高血压的最新分类定义标准数值"]],
113
"chip-cdn-2c": [["1型糖尿病性植物神经病变", " 1型糖尿病肾病IV期"], ["髂腰肌囊性占位", "髂肌囊肿"]],
122
"chip-cdn-2c": "macro",
129
# Write every parsed option to the log, one "name: value" line each, with the name
# padded to 20 columns.
parsed_options = vars(args)
for option in parsed_options:
    logger.info("{:20}: {}".format(option, parsed_options[option]))
132
# The lookup tables are keyed by lower-case dataset names, so normalise the
# user-supplied name before using it as a key.
dataset_key = args.dataset.lower()
args.dataset = dataset_key
label_list = LABEL_LIST[dataset_key]
134
if args.data_file is not None:
135
# Decode the sample file as UTF-8 explicitly: without an encoding, open() uses the
# platform locale, which can fail on or mis-decode the Chinese text handled here.
with open(args.data_file, "r", encoding="utf-8") as fp:
    # One sample per line; tab-separated fields make up a multi-text sample.
    input_data = [line.strip().split("\t") for line in fp]
# A line with a single field is passed on as a plain string, not a one-element list.
input_data = [fields[0] if len(fields) == 1 else fields for fields in input_data]
139
input_data = TEXT[args.dataset]
141
# Build the predictor from the parsed options and the label names of the selected dataset.
predictor = CLSPredictor(args, label_list)
142
predictor.predict(input_data)
145
if __name__ == "__main__":