# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

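# This script benchmarks BART mask infilling on GPU three ways: PaddleNLP
# FastGeneration (use_fast=True), plain PaddleNLP generate(), and Hugging Face
# transformers, and reports the average per-batch latency in milliseconds.
# Example invocation (the script name here is illustrative; the flags are the
# ones defined in parse_args below):
#   python bart_perf.py --decode_strategy sampling --top_k 4 --max_length 32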
import argparse
import time
from pprint import pprint

import paddle
import torch
from transformers import BartForConditionalGeneration as hf_bart_model

from paddlenlp.data import Pad
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer


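# Tokenize a batch of sentences and right-pad every sequence to the longest
# one so the batch can be packed into a single dense int64 Paddle tensor.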
def prepare_input(tokenizer, sentences):
    word_pad = Pad(tokenizer.pad_token_id, dtype="int64")
    tokenized = tokenizer(sentences)
    inputs = word_pad([i["input_ids"] for i in tokenized])
    input_ids = paddle.to_tensor(inputs)
    return input_ids


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default="bart-base",
        type=str,
        choices=["bart-base", "bart-large"],
        help="The BART model to use. Can be one of ['bart-base', 'bart-large'].",
    )
    parser.add_argument(
        "--decode_strategy",
        default="sampling",
        type=str,
        choices=["greedy_search", "beam_search", "sampling"],
        help="The decoding strategy. Can be one of ['greedy_search', 'beam_search', 'sampling'].",
    )
    parser.add_argument("--num_beams", default=4, type=int, help="The number of beams for beam search.")
    parser.add_argument(
        "--top_k", default=4, type=int, help="The number of highest-probability tokens kept for top-k sampling."
    )
    parser.add_argument(
        "--top_p", default=1.0, type=float, help="The cumulative probability threshold for top-p (nucleus) sampling."
    )
    parser.add_argument("--max_length", default=32, type=int, help="Maximum output length.")
    parser.add_argument("--use_fp16_decoding", action="store_true", help="Whether to use FP16 decoding to predict.")
    args = parser.parse_args()
    return args


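# Times FastGeneration first; with --use_fp16_decoding it stops there, since
# the PaddleNLP and Hugging Face baselines below run in FP32.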
def do_predict(args):
    place = "gpu"
    place = paddle.set_device(place)

    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
    model = BartForConditionalGeneration.from_pretrained(args.model_name_or_path)
    # Set evaluate mode
    model.eval()
    sentences = [
        "I love that girl, but <mask> does not <mask> me.",
        "She is so <mask> that I can not help glance at <mask>.",
        "Nothing's gonna <mask> my love for you.",
        "Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.",
    ]

    input_ids = prepare_input(tokenizer, sentences)

    num_loop = 100
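    # Each timed section below runs 100 iterations: the first 50 warm up the
    # GPU and kernel caches, and only the last 50 are measured.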
    with paddle.no_grad():
        for i in range(num_loop):
            # Start timing once the 50 warmup iterations are done.
            if i == 50:
                # PaddlePaddle >= 2.2
                paddle.device.cuda.synchronize(place)
                start = time.perf_counter()
            model.generate(
                input_ids=input_ids,
                max_length=args.max_length,
                decode_strategy=args.decode_strategy,
                top_k=args.top_k,
                top_p=args.top_p,
                num_beams=args.num_beams,
                early_stopping=True,
                use_fast=True,
                use_fp16_decoding=args.use_fp16_decoding,
            )
        paddle.device.cuda.synchronize(place)
        fast_cost = (time.perf_counter() - start) / 50 * 1000

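    # FP16 FastGeneration has no matching FP16 baseline below, so report the
    # latency and stop here.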
    if args.use_fp16_decoding:
        pprint(args)
        print("Fast FP16 cost:", fast_cost)
        return

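    # Baseline 1: the same generate() call without use_fast, i.e. the regular
    # PaddleNLP decoding loop on the same inputs.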
    with paddle.no_grad():
        for i in range(num_loop):
            # Start timing once the 50 warmup iterations are done.
            if i == 50:
                # PaddlePaddle >= 2.2
                paddle.device.cuda.synchronize(place)
                start = time.perf_counter()
            model.generate(
                input_ids=input_ids,
                max_length=args.max_length,
                decode_strategy=args.decode_strategy,
                top_k=args.top_k,
                top_p=args.top_p,
                num_beams=args.num_beams,
                early_stopping=True,
            )
        paddle.device.cuda.synchronize(place)
        pd_cost = (time.perf_counter() - start) / 50 * 1000

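    # Baseline 2: Hugging Face transformers. Reuse the tokenized ids and move
    # them to the CUDA device as a torch tensor.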
    device = torch.device("cuda:0")
    hf_model = hf_bart_model.from_pretrained("facebook/" + args.model_name_or_path)
    hf_model.to(device)
    hf_model.eval()
    hf_input_ids = prepare_input(tokenizer, sentences)
    hf_input_ids = torch.tensor(hf_input_ids.numpy())
    hf_input_ids = hf_input_ids.to(device)

    do_sample = args.decode_strategy == "sampling"
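    # Note: max_length is offset by +1 below, presumably to compensate for the
    # decoder start token that transformers counts toward max_length.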
    with torch.no_grad():
        for i in range(num_loop):
            # Start timing once the 50 warmup iterations are done.
            if i == 50:
                torch.cuda.synchronize()
                start = time.perf_counter()
            hf_model.generate(
                hf_input_ids,
                do_sample=do_sample,
                max_length=args.max_length + 1,
                top_k=args.top_k,
                top_p=args.top_p,
                num_beams=args.num_beams,
                no_repeat_ngram_size=0,
                length_penalty=0.0,
            )
        torch.cuda.synchronize()
        hf_cost = (time.perf_counter() - start) / 50 * 1000

    pprint(args)
    print("Fast FP32 cost:", fast_cost)
    print("PD cost:", pd_cost)
    print("HF cost:", hf_cost)
    print("Speed up Fast FP32/PD:", pd_cost / fast_cost)
    print("Speed up Fast FP32/HF:", hf_cost / fast_cost)


if __name__ == "__main__":
    args = parse_args()
    do_predict(args)
