paddlenlp
142 строки · 4.4 Кб
1# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2#
3# Licensed under the Apache License, Version 2.0 (the "License");
4# you may not use this file except in compliance with the License.
5# You may obtain a copy of the License at
6#
7# http://www.apache.org/licenses/LICENSE-2.0
8#
9# Unless required by applicable law or agreed to in writing, software
10# distributed under the License is distributed on an "AS IS" BASIS,
11# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12# See the License for the specific language governing permissions and
13# limitations under the License.
14import numpy as np
15
16from paddlenlp import SimpleServer
17from paddlenlp.server import BasePostHandler, TokenClsModelHandler
18
# Relation (predicate) names used in the SPO output.
# SPOPostHandler.process looks a name up with ``label_list[p]``, where ``p`` is
# the predicate index taken from the relation logits, so the position of each
# entry is significant.
# NOTE(review): presumably the order must match the label ids of the exported
# model -- confirm before reordering or inserting entries.
label_list = [
    "预防",
    "阶段",
    "就诊科室",
    "辅助治疗",
    "化疗",
    "放射治疗",
    "手术治疗",
    "实验室检查",
    "影像学检查",
    "辅助检查",
    "组织学检查",
    "内窥镜检查",
    "筛查",
    "多发群体",
    "发病率",
    "发病年龄",
    "多发地区",
    "发病性别倾向",
    "死亡率",
    "多发季节",
    "传播途径",
    "并发症",
    "病理分型",
    "相关(导致)",
    "鉴别诊断",
    "相关(转化)",
    "相关(症状)",
    "临床表现",
    "治疗后症状",
    "侵及周围组织转移的症状",
    "病因",
    "高危因素",
    "风险评估因素",
    "病史",
    "遗传因素",
    "发病机制",
    "病理生理",
    "药物治疗",
    "发病部位",
    "转移部位",
    "外侵部位",
    "预后状况",
    "预后生存率",
    "同义词",
]
65
66
class SPOPostHandler(BasePostHandler):
    """Turn entity and relation logits into entity strings and SPO triples.

    The handler reads the model handler output (``logits`` for entity
    start/end scores, ``logits_1`` for relation scores), decodes entity
    spans per input, keeps only relations whose subject and object both
    start at a decoded entity, and maps the spans back onto the input text.
    """

    def __init__(self):
        super().__init__()

    @staticmethod
    def _decode_entities(ent_scores, seq_len):
        """Decode entity spans for one input from its start/end scores.

        Args:
            ent_scores: 2-D array; column 0 holds start scores and column 1
                holds end scores, one row per token position.
            seq_len: Largest token position that may belong to an entity.

        Returns:
            A ``(spans, span_by_start)`` pair. ``spans`` is the list of
            ``(start, end)`` text offsets (inclusive); ``span_by_start`` maps
            the start *token* position to the same offsets.
        """
        starts = np.where(ent_scores[:, 0] > 0.5)[0]
        ends = np.where(ent_scores[:, 1] > 0.5)[0]
        spans = []
        span_by_start = {}
        for x in starts:
            # Position 0 is never an entity start (presumably a leading
            # special token -- confirm against the tokenizer); positions
            # beyond seq_len are outside the real text.
            if x == 0 or x > seq_len:
                continue
            # Pair the start with the nearest end at or after it.
            candidates = ends[ends >= x]
            if len(candidates) == 0:
                continue
            y = candidates[0]
            if y > seq_len:
                continue
            # Shift by one to convert token positions into text offsets.
            span = (int(x) - 1, int(y) - 1)
            span_by_start[int(x)] = span
            spans.append(span)
        return spans, span_by_start

    @staticmethod
    def _span_text(text, span):
        """Return the text covered by an inclusive ``(start, end)`` span."""
        # ``text`` may be a string or a sequence of tokens; str() keeps both
        # cases joinable.
        return "".join([str(d) for d in text[span[0] : span[1] + 1]])

    @classmethod
    def process(cls, data, parameters):
        """Build the entity / SPO response from the model handler output.

        Args:
            data: Mapping produced by the model handler. It must contain
                ``logits`` (entity start/end scores), ``logits_1`` (relation
                scores, indexed as batch x predicate x subject position x
                object position), ``attention_mask`` and ``data["text"]``
                (the original inputs).
            parameters: Request parameters; not used by this handler.

        Returns:
            ``{"entity": [...], "spo": [...]}`` with one list per input:
            entity strings, and ``(subject, predicate, object)`` tuples whose
            predicate is taken from ``label_list``.

        Raises:
            ValueError: If ``logits`` or ``logits_1`` is missing from ``data``.
        """
        missing = [key for key in ("logits", "logits_1") if key not in data]
        if missing:
            raise ValueError(
                "The output of model handler does not include {}, "
                "please check the model handler output. The model handler output:\n{}".format(
                    " and ".join(repr(key) for key in missing), data
                )
            )

        lengths = np.array(data["attention_mask"], dtype="float32").sum(axis=-1)
        ent_logits = np.array(data["logits"])
        spo_logits = np.array(data["logits_1"])

        ent_pred_list = []
        ent_idxs_list = []
        for idx, ent_scores in enumerate(ent_logits):
            # The attended length minus 2 is the last usable token position
            # (presumably excluding two special tokens -- confirm).
            spans, span_by_start = cls._decode_entities(ent_scores, lengths[idx] - 2)
            ent_pred_list.append(spans)
            ent_idxs_list.append(span_by_start)

        spo_preds = spo_logits > 0
        spo_pred_list = [[] for _ in range(len(spo_preds))]
        idxs, preds, subs, objs = np.nonzero(spo_preds)
        for idx, p_id, s_id, o_id in zip(idxs, preds, subs, objs):
            # A relation is kept only when both ends start at a decoded entity.
            obj = ent_idxs_list[idx].get(o_id, None)
            if obj is None:
                continue
            sub = ent_idxs_list[idx].get(s_id, None)
            if sub is None:
                continue
            spo_pred_list[idx].append((sub, p_id, obj))

        input_data = data["data"]["text"]
        if isinstance(input_data, str):
            # A bare string would be indexed per character below; treat it as
            # a batch holding that single text instead.
            input_data = [input_data]

        ent_list = []
        spo_list = []
        for i, (ents, rels) in enumerate(zip(ent_pred_list, spo_pred_list)):
            text = input_data[i]
            ent_list.append([cls._span_text(text, span) for span in ents])
            spo_list.append(
                [(cls._span_text(text, s), label_list[p], cls._span_text(text, o)) for s, p, o in rels]
            )

        return {"entity": ent_list, "spo": spo_list}
133
134
# Module-level server object on which the task below is registered.
# NOTE(review): presumably the serving launcher imports this module and looks
# up the name ``app`` -- confirm before renaming.
app = SimpleServer()
# Register the CBLUE SPO task: TokenClsModelHandler runs the model and
# SPOPostHandler converts its output into entities and SPO triples.
# NOTE(review): "models/cblue_spo" is presumably the route the task is served
# under, and the relative model_path is presumably resolved against the
# working directory the server is started from -- confirm both.
app.register(
    "models/cblue_spo",
    model_path="../../../export",
    tokenizer_name="ernie-health-chinese",
    model_handler=TokenClsModelHandler,
    post_handler=SPOPostHandler,
)
143