paddlenlp

Форк
0
142 строки · 4.4 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import numpy as np
15

16
from paddlenlp import SimpleServer
17
from paddlenlp.server import BasePostHandler, TokenClsModelHandler
18

19
label_list = [
20
    "预防",
21
    "阶段",
22
    "就诊科室",
23
    "辅助治疗",
24
    "化疗",
25
    "放射治疗",
26
    "手术治疗",
27
    "实验室检查",
28
    "影像学检查",
29
    "辅助检查",
30
    "组织学检查",
31
    "内窥镜检查",
32
    "筛查",
33
    "多发群体",
34
    "发病率",
35
    "发病年龄",
36
    "多发地区",
37
    "发病性别倾向",
38
    "死亡率",
39
    "多发季节",
40
    "传播途径",
41
    "并发症",
42
    "病理分型",
43
    "相关(导致)",
44
    "鉴别诊断",
45
    "相关(转化)",
46
    "相关(症状)",
47
    "临床表现",
48
    "治疗后症状",
49
    "侵及周围组织转移的症状",
50
    "病因",
51
    "高危因素",
52
    "风险评估因素",
53
    "病史",
54
    "遗传因素",
55
    "发病机制",
56
    "病理生理",
57
    "药物治疗",
58
    "发病部位",
59
    "转移部位",
60
    "外侵部位",
61
    "预后状况",
62
    "预后生存率",
63
    "同义词",
64
]
65

66

67
class SPOPostHandler(BasePostHandler):
68
    def __init__(self):
69
        super().__init__()
70

71
    @classmethod
72
    def process(cls, data, parameters):
73
        if "logits" not in data or "logits_1" not in data:
74
            raise ValueError(
75
                "The output of model handler do not include the 'logits', "
76
                " please check the model handler output. The model handler output:\n{}".format(data)
77
            )
78
        lengths = np.array(data["attention_mask"], dtype="float32").sum(axis=-1)
79
        ent_logits = np.array(data["logits"])
80
        spo_logits = np.array(data["logits_1"])
81
        ent_pred_list = []
82
        ent_idxs_list = []
83
        for idx, ent_pred in enumerate(ent_logits):
84
            seq_len = lengths[idx] - 2
85
            start = np.where(ent_pred[:, 0] > 0.5)[0]
86
            end = np.where(ent_pred[:, 1] > 0.5)[0]
87
            ent_pred = []
88
            ent_idxs = {}
89
            for x in start:
90
                y = end[end >= x]
91
                if (x == 0) or (x > seq_len):
92
                    continue
93
                if len(y) > 0:
94
                    y = y[0]
95
                    if y > seq_len:
96
                        continue
97
                    ent_idxs[x] = (x - 1, y - 1)
98
                    ent_pred.append((x - 1, y - 1))
99
            ent_pred_list.append(ent_pred)
100
            ent_idxs_list.append(ent_idxs)
101

102
        spo_preds = spo_logits > 0
103
        spo_pred_list = [[] for _ in range(len(spo_preds))]
104
        idxs, preds, subs, objs = np.nonzero(spo_preds)
105
        for idx, p_id, s_id, o_id in zip(idxs, preds, subs, objs):
106
            obj = ent_idxs_list[idx].get(o_id, None)
107
            if obj is None:
108
                continue
109
            sub = ent_idxs_list[idx].get(s_id, None)
110
            if sub is None:
111
                continue
112
            spo_pred_list[idx].append((tuple(sub), p_id, tuple(obj)))
113
        input_data = data["data"]["text"]
114
        ent_list = []
115
        spo_list = []
116
        for i, (ent, rel) in enumerate(zip(ent_pred_list, spo_pred_list)):
117
            cur_ent_list = []
118
            cur_spo_list = []
119
            for sid, eid in ent:
120
                cur_ent_list.append("".join([str(d) for d in input_data[i][sid : eid + 1]]))
121
            for s, p, o in rel:
122
                cur_spo_list.append(
123
                    (
124
                        "".join([str(d) for d in input_data[i][s[0] : s[1] + 1]]),
125
                        label_list[p],
126
                        "".join([str(d) for d in input_data[i][o[0] : o[1] + 1]]),
127
                    )
128
                )
129
            ent_list.append(cur_ent_list)
130
            spo_list.append(cur_spo_list)
131

132
        return {"entity": ent_list, "spo": spo_list}
133

134

135
app = SimpleServer()
136
app.register(
137
    "models/cblue_spo",
138
    model_path="../../../export",
139
    tokenizer_name="ernie-health-chinese",
140
    model_handler=TokenClsModelHandler,
141
    post_handler=SPOPostHandler,
142
)
143

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.