paddlenlp
129 строк · 3.7 Кб
1# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import numpy as np15
16from paddlenlp import SimpleServer17from paddlenlp.server import BasePostHandler, TokenClsModelHandler18
# Maps the short English entity-type code used in the tag names to the
# Chinese display name returned to the client.
en_to_cn = dict(
    bod="身体",
    mic="微生物类",
    dis="疾病",
    sym="临床表现",
    pro="医疗程序",
    equ="医疗设备",
    dru="药物",
    dep="科室",
    ite="医学检验项目",
)

# Tag vocabularies, one per output head. Index 0 covers every entity type
# except "sym"; index 1 covers only "sym". Each type contributes its
# B/I/E/S tags in that order, and "O" is always the last label.
label_list = [
    [
        prefix + "-" + etype
        for etype in ("bod", "dis", "pro", "dru", "ite", "mic", "equ", "dep")
        for prefix in "BIES"
    ]
    + ["O"],
    [prefix + "-sym" for prefix in "BIES"] + ["O"],
]
69
70
71def _extract_chunk(tokens):72chunks = set()73start_idx, cur_idx = 0, 074while cur_idx < len(tokens):75if tokens[cur_idx][0] == "B":76start_idx = cur_idx77cur_idx += 178while cur_idx < len(tokens) and tokens[cur_idx][0] == "I":79if tokens[cur_idx][2:] == tokens[start_idx][2:]:80cur_idx += 181else:82break83if cur_idx < len(tokens) and tokens[cur_idx][0] == "E":84if tokens[cur_idx][2:] == tokens[start_idx][2:]:85chunks.add((tokens[cur_idx][2:], start_idx - 1, cur_idx))86cur_idx += 187elif tokens[cur_idx][0] == "S":88chunks.add((tokens[cur_idx][2:], cur_idx - 1, cur_idx))89cur_idx += 190else:91cur_idx += 192return list(chunks)93
94
class NERPostHandler(BasePostHandler):
    """Post handler that turns two-head tagging logits into entity spans.

    The model handler output is expected to carry two logit arrays:
    ``"logits"``, decoded with ``label_list[0]``, and ``"logits_1"``, decoded
    with ``label_list[1]``. Spans from both heads are merged per sequence.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def process(cls, data, parameters):
        """Decode the model logits into typed entity spans.

        Args:
            data: Mapping returned by the model handler. Must contain the
                keys ``"logits"`` and ``"logits_1"``; each value is converted
                with ``np.array`` and reduced with ``argmax`` over the last
                axis to obtain label ids.
            parameters: Not used by this handler.

        Returns:
            ``{"entity": [...]}`` with one inner list per sequence. Each item
            is a dict with ``"type"`` (the Chinese name from ``en_to_cn``),
            ``"start_id"`` and ``"end_id"`` as produced by ``_extract_chunk``.

        Raises:
            ValueError: If ``"logits"`` or ``"logits_1"`` is missing from
                ``data``.
        """
        # Report exactly which keys are missing; the original message only
        # ever mentioned 'logits', even when 'logits_1' was the absent one.
        missing = [key for key in ("logits", "logits_1") if key not in data]
        if missing:
            raise ValueError(
                "The output of model handler does not include {}, "
                "please check the model handler output. The model handler output:\n{}".format(
                    " and ".join(repr(key) for key in missing), data
                )
            )

        # Pick the highest-scoring label id at every position for each head.
        oth_label_ids = np.argmax(np.array(data["logits"]), axis=-1)
        sym_label_ids = np.argmax(np.array(data["logits_1"]), axis=-1)

        entity = []
        for oth_ids, sym_ids in zip(oth_label_ids, sym_label_ids):
            oth_tags = [label_list[0][x] for x in oth_ids]
            sym_tags = [label_list[1][x] for x in sym_ids]
            chunks = _extract_chunk(oth_tags) + _extract_chunk(sym_tags)
            entity.append(
                [{"type": en_to_cn[etype], "start_id": sid, "end_id": eid} for etype, sid, eid in chunks]
            )
        return {"entity": entity}
121
# Create the serving application and register the NER task on it, wiring the
# token-classification model handler to the NER post handler defined above.
app = SimpleServer()
app.register(
    # First positional argument; presumably the task/endpoint name under
    # which the model is served -- verify against SimpleServer.register.
    "models/cblue_ner",
    # Relative path to the exported model; presumably resolved against the
    # directory the server is launched from -- TODO confirm.
    model_path="../../../export_ner",
    tokenizer_name="ernie-health-chinese",
    model_handler=TokenClsModelHandler,
    post_handler=NERPostHandler,
)
130