# paddlenlp — BART FastGeneration benchmark script
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import argparse
import time
from pprint import pprint

import paddle
import torch
from transformers import BartForConditionalGeneration as hf_bart_model

from paddlenlp.data import Pad
from paddlenlp.transformers import BartForConditionalGeneration, BartTokenizer


def prepare_input(tokenizer, sentences):
    """Tokenize *sentences* and pad them into one int64 batch tensor.

    Args:
        tokenizer: a PaddleNLP tokenizer providing ``pad_token_id``.
        sentences: list of raw input strings.

    Returns:
        A 2-D ``paddle.Tensor`` of padded token ids.
    """
    pad_batch = Pad(tokenizer.pad_token_id, dtype="int64")
    encoded = tokenizer(sentences)
    padded = pad_batch([example["input_ids"] for example in encoded])
    return paddle.to_tensor(padded)


def parse_args():
    """Build and parse the command-line arguments for this benchmark."""
    arg_parser = argparse.ArgumentParser()
    arg_parser.add_argument(
        "--model_name_or_path",
        type=str,
        default="bart-base",
        choices=["bart-base", "bart-large"],
        help="The model name to specify the bart to use. Can be one of ['bart-base', 'bart-large']. ",
    )
    arg_parser.add_argument(
        "--decode_strategy",
        type=str,
        default="sampling",
        choices=["greedy_search", "beam_search", "sampling"],
        help="The decoding strategy. Can be one of ['greedy_search', 'beam_search', 'sampling']",
    )
    arg_parser.add_argument("--num_beams", type=int, default=4, help="The parameters for beam search. ")
    arg_parser.add_argument("--top_k", type=int, default=4, help="The number of candidate to procedure beam search. ")
    arg_parser.add_argument(
        "--top_p", type=float, default=1.0, help="The probability threshold to procedure topp sampling. "
    )
    arg_parser.add_argument("--max_length", type=int, default=32, help="Maximum output length. ")
    arg_parser.add_argument("--use_fp16_decoding", action="store_true", help="Whether to use fp16 decoding to predict. ")
    return arg_parser.parse_args()


def do_predict(args):
    """Benchmark BART generation and print average per-batch latency (ms).

    Three configurations run on the same masked-LM sentences:
      1. PaddleNLP FastGeneration (``use_fast=True``), FP32 or FP16;
      2. plain PaddleNLP ``model.generate``;
      3. HuggingFace ``transformers`` BART on CUDA.

    The first half of ``num_loop`` iterations warms up the kernels; only the
    second half is timed. With ``args.use_fp16_decoding`` set, only the
    FastGeneration FP16 number is printed and the function returns early.

    Args:
        args: parsed command-line namespace from ``parse_args()``.
    """
    place = paddle.set_device("gpu")

    tokenizer = BartTokenizer.from_pretrained(args.model_name_or_path)
    model = BartForConditionalGeneration.from_pretrained(args.model_name_or_path)
    # Set evaluate mode (dropout disabled for inference).
    model.eval()

    sentences = [
        "I love that girl, but <mask> does not <mask> me.",
        "She is so <mask> that I can not help glance at <mask>.",
        "Nothing's gonna <mask> my love for you.",
        "Drop everything now. Meet me in the pouring <mask>. Kiss me on the sidewalk.",
    ]

    input_ids = prepare_input(tokenizer, sentences)

    num_loop = 100
    warmup = num_loop // 2  # timed iterations = num_loop - warmup

    with paddle.no_grad():
        for i in range(num_loop):
            # Start the clock only once the warmup iterations are done.
            if i == warmup:
                # PaddlePaddle >= 2.2
                paddle.device.cuda.synchronize(place)
                start = time.perf_counter()
            model.generate(
                input_ids=input_ids,
                max_length=args.max_length,
                decode_strategy=args.decode_strategy,
                top_k=args.top_k,
                top_p=args.top_p,
                num_beams=args.num_beams,
                early_stopping=True,
                use_fast=True,
                use_fp16_decoding=args.use_fp16_decoding,
            )
        paddle.device.cuda.synchronize(place)
        fast_cost = (time.perf_counter() - start) / (num_loop - warmup) * 1000

    if args.use_fp16_decoding:
        pprint(args)
        print("Fast FP16 cost:", fast_cost)
        return

    with paddle.no_grad():
        for i in range(num_loop):
            # For warmup.
            if i == warmup:
                # PaddlePaddle >= 2.2
                paddle.device.cuda.synchronize(place)
                start = time.perf_counter()
            model.generate(
                input_ids=input_ids,
                max_length=args.max_length,
                decode_strategy=args.decode_strategy,
                top_k=args.top_k,
                top_p=args.top_p,
                num_beams=args.num_beams,
                early_stopping=True,
            )
        paddle.device.cuda.synchronize(place)
        pd_cost = (time.perf_counter() - start) / (num_loop - warmup) * 1000

    device = torch.device("cuda:0")
    hf_model = hf_bart_model.from_pretrained("facebook/" + args.model_name_or_path)
    hf_model.to(device)
    hf_model.eval()
    # Reuse the already-tokenized batch: same tokenizer + sentences would
    # produce identical ids, so re-tokenizing is redundant.
    hf_input_ids = torch.tensor(input_ids.numpy()).to(device)

    # HF only samples when do_sample=True; mirror the Paddle strategy.
    do_sample = args.decode_strategy == "sampling"
    with torch.no_grad():
        for i in range(num_loop):
            # For warmup.
            if i == warmup:
                torch.cuda.synchronize()
                start = time.perf_counter()
            hf_model.generate(
                hf_input_ids,
                do_sample=do_sample,
                # +1 presumably aligns HF's length accounting with Paddle's
                # — TODO confirm against both libraries' generate() docs.
                max_length=args.max_length + 1,
                top_k=args.top_k,
                top_p=args.top_p,
                num_beams=args.num_beams,
                no_repeat_ngram_size=0,
                length_penalty=0.0,
            )
        torch.cuda.synchronize()
        hf_cost = (time.perf_counter() - start) / (num_loop - warmup) * 1000

    pprint(args)
    print("Fast FP32 cost:", fast_cost)
    print("PD cost:", pd_cost)
    print("HF cost:", hf_cost)
    print("Speed up Fast FP32/PD:", pd_cost / fast_cost)
    print("Speed up Fast FP32/HF:", hf_cost / fast_cost)


if __name__ == "__main__":
    # Parse CLI flags and run the benchmark.
    do_predict(parse_args())