# Page header captured from the web source viewer this file was copied from:
# repository "paddlenlp" · Fork · 0 · 129 lines · 3.7 KB
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
15

16
from paddlenlp import SimpleServer
17
from paddlenlp.server import BasePostHandler, TokenClsModelHandler
18

# Maps the English entity-type abbreviations used in the tag labels
# (the part after the dash, e.g. "dis" in "B-dis") to the Chinese type
# names emitted in the post handler's "type" field.
en_to_cn = {
    "bod": "身体",  # body part
    "mic": "微生物类",  # microorganism
    "dis": "疾病",  # disease
    "sym": "临床表现",  # clinical manifestation (symptom)
    "pro": "医疗程序",  # medical procedure
    "equ": "医疗设备",  # medical equipment
    "dru": "药物",  # drug
    "dep": "科室",  # hospital department
    "ite": "医学检验项目",  # medical test item
}

# Tag vocabularies for the model's two classification heads. The post handler
# indexes these lists directly with argmax label ids, so the order here must
# stay in sync with the exported model's label order (presumably fixed at
# training time - do not reorder). Tags use the BIOES scheme: B-/I-/E- mark
# the begin/inside/end of a multi-token entity, S- a single-token entity, and
# O a token outside any entity.
label_list = [
    # Head 0 (decoded from "logits"): every entity type except symptoms.
    [
        "B-bod",
        "I-bod",
        "E-bod",
        "S-bod",
        "B-dis",
        "I-dis",
        "E-dis",
        "S-dis",
        "B-pro",
        "I-pro",
        "E-pro",
        "S-pro",
        "B-dru",
        "I-dru",
        "E-dru",
        "S-dru",
        "B-ite",
        "I-ite",
        "E-ite",
        "S-ite",
        "B-mic",
        "I-mic",
        "E-mic",
        "S-mic",
        "B-equ",
        "I-equ",
        "E-equ",
        "S-equ",
        "B-dep",
        "I-dep",
        "E-dep",
        "S-dep",
        "O",
    ],
    # Head 1 (decoded from "logits_1"): symptoms ("sym") only.
    ["B-sym", "I-sym", "E-sym", "S-sym", "O"],
]


71
def _extract_chunk(tokens):
72
    chunks = set()
73
    start_idx, cur_idx = 0, 0
74
    while cur_idx < len(tokens):
75
        if tokens[cur_idx][0] == "B":
76
            start_idx = cur_idx
77
            cur_idx += 1
78
            while cur_idx < len(tokens) and tokens[cur_idx][0] == "I":
79
                if tokens[cur_idx][2:] == tokens[start_idx][2:]:
80
                    cur_idx += 1
81
                else:
82
                    break
83
            if cur_idx < len(tokens) and tokens[cur_idx][0] == "E":
84
                if tokens[cur_idx][2:] == tokens[start_idx][2:]:
85
                    chunks.add((tokens[cur_idx][2:], start_idx - 1, cur_idx))
86
                    cur_idx += 1
87
        elif tokens[cur_idx][0] == "S":
88
            chunks.add((tokens[cur_idx][2:], cur_idx - 1, cur_idx))
89
            cur_idx += 1
90
        else:
91
            cur_idx += 1
92
    return list(chunks)
93

94

class NERPostHandler(BasePostHandler):
    """Decode the NER model's two logit heads into entity spans.

    ``"logits"`` is decoded with ``label_list[0]`` (every entity type except
    symptoms) and ``"logits_1"`` with ``label_list[1]`` (symptoms). The two
    heads are decoded independently and their chunks are merged, so a symptom
    span and a non-symptom span may overlap in the output.
    """

    def __init__(self):
        super().__init__()

    @classmethod
    def process(cls, data, parameters):
        """Convert model-handler logits into per-example entity lists.

        Args:
            data: model handler output. It must contain ``"logits"`` (scores
                over ``label_list[0]``) and ``"logits_1"`` (scores over
                ``label_list[1]``). Each must be 3-D, ``(batch, seq_len,
                num_labels)``: argmax is taken over the last axis and the
                result is iterated per example, then per token.
            parameters: not used by this handler; accepted for the
                ``BasePostHandler.process`` call signature.

        Returns:
            ``{"entity": [...]}`` with one list per example. Each item is a dict
            with ``"type"`` (Chinese name from ``en_to_cn``), ``"start_id"`` and
            ``"end_id"`` (as produced by ``_extract_chunk``), ordered by
            position.

        Raises:
            ValueError: if ``"logits"`` or ``"logits_1"`` is missing from ``data``.
        """
        missing = [key for key in ("logits", "logits_1") if key not in data]
        if missing:
            raise ValueError(
                "The output of model handler does not include {}, please check the model handler output. "
                "The model handler output:\n{}".format(", ".join("'{}'".format(key) for key in missing), data)
            )
        # Greedy decoding: take the highest-scoring tag id for every token of both heads.
        tokens_oth = np.argmax(np.array(data["logits"]), axis=-1)
        tokens_sym = np.argmax(np.array(data["logits_1"]), axis=-1)
        entity = []
        for oth_ids, sym_ids in zip(tokens_oth, tokens_sym):
            token_oth = [label_list[0][x] for x in oth_ids]
            token_sym = [label_list[1][x] for x in sym_ids]
            chunks = _extract_chunk(token_oth) + _extract_chunk(token_sym)
            # Sort explicitly so the output order does not depend on
            # _extract_chunk's internal ordering and follows text position.
            chunks = sorted(chunks, key=lambda chunk: (chunk[1], chunk[2], chunk[0]))
            entity.append([{"type": en_to_cn[etype], "start_id": sid, "end_id": eid} for etype, sid, eid in chunks])
        return {"entity": entity}


# Module-level server object. It must stay at module level (no __main__
# guard) so a loader that imports this module can find ``app`` - presumably
# started via ``paddlenlp server <module>:app``; confirm against the deploy docs.
app = SimpleServer()
app.register(
    # NOTE(review): first positional argument looks like the task/route name
    # the model is served under - confirm against SimpleServer.register.
    "models/cblue_ner",
    # Relative path; presumably resolved against the working directory the
    # server is launched from - confirm.
    model_path="../../../export_ner",
    tokenizer_name="ernie-health-chinese",
    model_handler=TokenClsModelHandler,
    post_handler=NERPostHandler,
)

# Cookie notice captured from the web source viewer this file was copied from:
# Use of cookies
#
# We use cookies in accordance with the Privacy Policy and the Cookie Policy.
#
# By clicking "Accept", you consent to SberTech JSC processing your personal
# data in order to improve our website and the GitVerse service and to make
# them more convenient to use.
#
# You can disable cookies yourself in your browser settings.