paddlenlp

Форк
0
105 строк · 4.8 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16
import json
17

18
from tqdm import tqdm
19

20
from paddlenlp import Taskflow
21

22

23
# yapf: disable
def parse_args():
    """Parse the command-line options of this answer-extraction script.

    The main block reads --model_path, --source_file_path, --target_file_path,
    --batch_size and --position_prob; the decoding-related options are accepted
    for command-line compatibility but are not used by this script.

    Returns:
        argparse.Namespace holding every option defined below.
    """
    # The module docstring (if any) is the --help description; passing it
    # positionally would have set `prog` instead.
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument('--model_path', type=str, default=None, help='the model path to be loaded for the information_extraction taskflow')
    parser.add_argument('--source_file_path', type=str, default=None, help='the source file path')
    parser.add_argument('--target_file_path', type=str, default=None, help='the target json file path')
    parser.add_argument('--all_sample_num', type=int, default=None, help='the test sample number when convert_json_to_data')
    parser.add_argument('--num_return_sequences', type=int, default=3, help='the number of return sequences for each input sample, it should be less than num_beams')
    parser.add_argument('--batch_size', type=int, default=1, help='the batch size when using taskflow')
    parser.add_argument('--position_prob', type=float, default=0.01, help='the probability threshold (taskflow position_prob) below which extracted answer spans are discarded')
    parser.add_argument('--decode_strategy', type=str, default=None, help='the decode strategy')
    parser.add_argument('--num_beams', type=int, default=6, help='the number of beams when using beam search')
    parser.add_argument('--num_beam_groups', type=int, default=1, help='the number of beam groups when using diverse beam search')
    parser.add_argument('--diversity_rate', type=float, default=0.0, help='the diversity_rate when using diverse beam search')
    # top_k counts candidate tokens, so it must be an integer.
    parser.add_argument('--top_k', type=int, default=0, help='the top_k when using sampling decoding strategy')
    parser.add_argument('--top_p', type=float, default=1.0, help='the top_p when using sampling decoding strategy')
    parser.add_argument('--temperature', type=float, default=1.0, help='the temperature when using sampling decoding strategy')
    args = parser.parse_args()
    return args
# yapf: enable
43

44

45
def _build_answer_record(paragraph, predict_dict):
    """Convert one model prediction into an output record.

    Args:
        paragraph: the input text the prediction was made for.
        predict_dict: one prediction; answers, if present, are read from its
            "答案" key as dicts carrying "text" and "probability".

    Returns:
        dict with "context" (the paragraph) and "answer_candidates", a list of
        (text, probability) tuples in descending probability order.
    """
    answer_dicts = predict_dict["答案"] if "答案" in predict_dict else []
    candidates = [(answer_dict["text"], answer_dict["probability"]) for answer_dict in answer_dicts]
    return {
        "context": paragraph,
        # Most probable answer first; sorted() is stable, so ties keep model order.
        "answer_candidates": sorted(candidates, key=lambda candidate: -candidate[1]),
    }


def _predict_batch(batch, model, wf):
    """Run `model` on one batch of paragraphs and return their output records.

    When `wf` is truthy, each record is also written to it as one JSON line.
    """
    records = []
    for predict_dict, paragraph in zip(model(batch), batch):
        record = _build_answer_record(paragraph, predict_dict)
        if wf:
            wf.write(json.dumps(record, ensure_ascii=False) + "\n")
        records.append(record)
    return records


def answer_generation_from_paragraphs(paragraphs, batch_size=16, model=None, wf=None):
    """Generate answer candidates from the given paragraphs.

    Args:
        paragraphs: iterable of paragraph strings.
        batch_size: number of paragraphs passed to `model` per call; must be >= 1.
        model: callable that takes a list of paragraphs and returns one
            prediction dict per paragraph.
        wf: optional writable text file; every record is written to it as one
            JSON line.

    Returns:
        list with one record per paragraph, each a dict holding "context" and
        "answer_candidates" (see _build_answer_record).

    Raises:
        ValueError: if batch_size is smaller than 1.
    """
    if batch_size < 1:
        raise ValueError("batch_size must be a positive integer, got {}".format(batch_size))
    result = []
    buffer = []
    for paragraph_tobe in tqdm(paragraphs):
        buffer.append(paragraph_tobe)
        if len(buffer) == batch_size:
            result.extend(_predict_batch(buffer, model, wf))
            buffer = []
    # Flush the final partial batch; otherwise the trailing
    # len(paragraphs) % batch_size paragraphs would be silently dropped.
    if buffer:
        result.extend(_predict_batch(buffer, model, wf))
    return result
72

73

74
def read_paragraphs(source_file_path):
    """Read the input paragraphs from `source_file_path`.

    A path ending in ".json" is parsed as JSON lines. Each non-blank line must be
    an object with a "context" or "content" key; "context" wins when both exist.
    Any other path is read as plain text with one paragraph per line, blank
    lines included. Every paragraph is stripped of surrounding whitespace.

    Raises:
        ValueError: if a JSON line has neither a "context" nor a "content" key.
    """
    paragraphs = []
    with open(source_file_path, "r", encoding="utf-8") as rf:
        if source_file_path.endswith(".json"):
            for line_no, json_line in enumerate(rf, start=1):
                # json.loads fails on blank lines, so skip them rather than crash.
                if not json_line.strip():
                    continue
                line_dict = json.loads(json_line)
                if "context" in line_dict:
                    paragraphs.append(line_dict["context"].strip())
                elif "content" in line_dict:
                    paragraphs.append(line_dict["content"].strip())
                else:
                    raise ValueError(
                        '{}:{}: expected a "context" or "content" key'.format(source_file_path, line_no)
                    )
        else:
            paragraphs.extend(line.strip() for line in rf)
    return paragraphs


if __name__ == "__main__":
    args = parse_args()
    # Validate and read the input before loading the model so bad input fails
    # fast; an explicit check also survives `python -O`, unlike `assert`.
    if not args.source_file_path:
        raise SystemExit("error: --source_file_path is required")
    paragraphs = read_paragraphs(args.source_file_path)

    schema = ["答案"]
    answer_generator = Taskflow(
        "information_extraction",
        schema=schema,
        task_path=args.model_path,
        batch_size=args.batch_size,
        position_prob=args.position_prob,
    )

    if args.target_file_path:
        # `with` closes (and flushes) the output file even if generation raises.
        with open(args.target_file_path, "w", encoding="utf-8") as wf:
            answer_generation_from_paragraphs(paragraphs, batch_size=args.batch_size, model=answer_generator, wf=wf)
    else:
        # No target file: results are computed but not written anywhere.
        answer_generation_from_paragraphs(paragraphs, batch_size=args.batch_size, model=answer_generator)

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.