# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
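"""Inference demo for PLATO-XL (UnifiedTransformerLMHeadModel) with optional
FastGeneration acceleration and FasterTransformer model parallelism.

Example commands (a sketch; the script name `infer.py` and the GPU ids are
assumptions, adjust them to your setup):

    # Single GPU with FastGeneration and FP16 decoding:
    python infer.py --use_faster --use_fp16

    # Multi-GPU model parallelism via the Paddle distributed launcher:
    python -m paddle.distributed.launch --gpus "0,1,2,3" infer.py --use_faster --use_fp16
"""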
import argparse
import os
import time
from pprint import pprint

import paddle

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.ops import enable_ft_para, get_ft_para_conf
from paddlenlp.trainer.argparser import strtobool
from paddlenlp.transformers import (
    UnifiedTransformerLMHeadModel,
    UnifiedTransformerTokenizer,
)


def setup_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_role", type=strtobool, default=True, help="Whether to use role embeddings.")
    parser.add_argument(
        "--position_style",
        default="relative",
        choices=["continuous", "relative"],
        type=str,
        help="The type of positional embedding. Default is relative.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
    parser.add_argument(
        "--num_return_sequences", default=1, type=int, help="The number of returned sequences for each sample."
    )
    parser.add_argument("--max_out_len", default=64, type=int, help="Maximum output sequence length.")
    parser.add_argument("--min_out_len", default=1, type=int, help="Minimum output sequence length.")
    parser.add_argument(
        "--topk", default=1, type=int, help="The number of highest-probability tokens to keep for top-k sampling."
    )
    parser.add_argument("--topp", default=1.0, type=float, help="The cumulative probability for top-p filtering.")
    parser.add_argument("--temperature", default=1.0, type=float, help="The sampling temperature.")
    parser.add_argument("--use_faster", action="store_true", help="Whether to use FastGeneration.")
    parser.add_argument(
        "--use_fp16",
        action="store_true",
        help="Whether to use fp16 to predict. Only available when `use_faster` is True.",
    )
    parser.add_argument("--profile", action="store_true", help="Whether to profile generation.")
    args = parser.parse_args()
    pprint(args.__dict__)
    return args


def profile(batch_size, total_step=50, warmup_step=10, rank=0):
    def _wrapper(func):
        def _impl(*args, **kwargs):
            # Run `warmup_step` untimed iterations first, then time the rest.
            for i in range(total_step):
                if i == warmup_step:
                    paddle.device.cuda.synchronize()
                    start_time = time.time()
                out = func(*args, **kwargs)
            paddle.device.cuda.synchronize()
            end_time = time.time()
            # Under model parallelism, report from a single rank (or all ranks if None).
            if rank is None or get_ft_para_conf().rank == rank:
                time_interval = end_time - start_time
                num_batch = total_step - warmup_step
                print("Latency: %.2fs, QPS: %.2f" % (time_interval / num_batch, num_batch * batch_size / time_interval))
            return out

        return _impl

    return _wrapper


def postprocess_response(token_ids, tokenizer):
    """Post-process the decoded sequence. Truncate from the first <eos>."""
    eos_pos = len(token_ids)
    for i, tok_id in enumerate(token_ids):
        if tok_id == tokenizer.sep_token_id:
            eos_pos = i
            break
    token_ids = token_ids[:eos_pos]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    tokens = tokenizer.merge_subword(tokens)
    response = " ".join(tokens)
    return response
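

# A worked example of the truncation above (the token ids are made up for
# illustration; assume `tokenizer.sep_token_id == 2`):
#     postprocess_response([15, 27, 2, 9, 4], tokenizer)
# keeps only [15, 27], converts those ids to tokens, merges subwords, and
# joins them with spaces into the final response string.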


def main(args):
    # For memory saving when using FastGeneration:
    # If the environment variable `PPFG_QKV_MEM_OPT` is set and the q/k/v
    # weights are fused, FastGeneration will try to delete the original
    # unfused weights. Note that rolling back to the original model is no
    # longer guaranteed if the faster model fails once the original weights
    # have been deleted.
    os.environ["PPFG_QKV_MEM_OPT"] = "1"
    if args.use_fp16:
        assert args.use_faster, "Only supports FP16 when using FastGeneration."
        paddle.set_default_dtype("float16")
    # Enable FasterTransformer model parallelism and bind each process to the
    # GPU matching its parallel rank.
    enable_ft_para()
    # TODO(guosheng): Maybe device can be set in `enable_ft_para`
    paddle.set_device("gpu:" + str(get_ft_para_conf().rank))

    if args.profile:
        UnifiedTransformerLMHeadModel.generate = profile(args.batch_size)(UnifiedTransformerLMHeadModel.generate)
    tokenizer = UnifiedTransformerTokenizer.from_pretrained("plato-xl")
    model = UnifiedTransformerLMHeadModel.from_pretrained("plato-xl")
    model.eval()

    history = [
        "hi , Mary ! What do you usually like to do in your spare time ?",
        "well , I spend a lot of time watching movies .",
        "what a coincidence ! I always watch a lot of movies , too .",
        "oh really , Frank ? What kind of movies do you like ?",
    ]
    inputs = [history] * args.batch_size
    inputs = list(
        map(
            lambda history: tokenizer.dialogue_encode(
                history=history,
                add_start_token_as_response=True,
                return_length=True,
                return_role_ids=args.use_role,
                position_style=args.position_style,
            ),
            inputs,
        )
    )
    collator = DataCollatorWithPadding(tokenizer)
    data = collator(inputs)
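
    # At this point `data` is a dict of padded Tensors. The keys assumed here
    # ("input_ids", "token_type_ids", "position_ids", "attention_mask",
    # "seq_len", and "role_ids" when --use_role is set) match the fields
    # consumed by `model.generate` in the next call.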

    outputs, _ = model.generate(
        input_ids=data["input_ids"],
        token_type_ids=data["token_type_ids"],
        position_ids=data["position_ids"],
        attention_mask=data["attention_mask"].cast("float32"),  # TODO(guosheng): remove this cast
        role_ids=data.get("role_ids", None),
        seq_len=data["seq_len"],
        max_length=args.max_out_len,
        min_length=args.min_out_len,
        decode_strategy="sampling",
        top_k=args.topk,
        top_p=args.topp,
        temperature=args.temperature,
        num_return_sequences=args.num_return_sequences,
        use_fast=args.use_faster,
        use_fp16_decoding=args.use_fp16,
    )
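
    # `outputs` holds the sampled token ids, one row per returned sequence
    # (batch_size * num_return_sequences rows); the discarded second return
    # value contains the per-sequence scores.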

    # Only let the first process print the outputs.
    if get_ft_para_conf().rank == 0:
        for i in range(len(outputs)):
            result = postprocess_response(outputs[i].numpy(), tokenizer)
            print("Result:", result)


if __name__ == "__main__":
    args = setup_args()
    main(args)