# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
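
# Example launch (illustrative): since `enable_ft_para()` is used below, the
# script is typically started with Paddle's distributed launcher so that
# FastGeneration can run with tensor parallelism across GPUs, e.g. assuming
# this file is saved as `infer.py`:
#
#     python -m paddle.distributed.launch --gpus "0,1,2,3" infer.py \
#         --use_faster --use_fp16
#
# A plain single-GPU `python infer.py` run also works.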

import argparse
import os
import time
from pprint import pprint

import paddle

from paddlenlp.data import DataCollatorWithPadding
from paddlenlp.ops import enable_ft_para, get_ft_para_conf
from paddlenlp.trainer.argparser import strtobool
from paddlenlp.transformers import (
    UnifiedTransformerLMHeadModel,
    UnifiedTransformerTokenizer,
)


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--use_role", type=strtobool, default=True, help="Whether to use role embeddings.")
    parser.add_argument(
        "--position_style",
        default="relative",
        choices=["continuous", "relative"],
        type=str,
        help="The type of positional embedding. Defaults to relative.",
    )
    parser.add_argument("--batch_size", default=1, type=int, help="Batch size.")
    parser.add_argument(
        "--num_return_sequences", default=1, type=int, help="The number of returned sequences for each sample."
    )
    parser.add_argument("--max_out_len", default=64, type=int, help="Maximum output sequence length.")
    parser.add_argument("--min_out_len", default=1, type=int, help="Minimum output sequence length.")
    parser.add_argument(
        "--topk", default=1, type=int, help="The number of highest-probability tokens to keep for top-k sampling."
    )
    parser.add_argument("--topp", default=1.0, type=float, help="The cumulative probability for top-p filtering.")
    parser.add_argument("--temperature", default=1.0, type=float, help="The sampling temperature.")
    parser.add_argument("--use_faster", action="store_true", help="Whether to use FastGeneration.")
    parser.add_argument(
        "--use_fp16",
        action="store_true",
        help="Whether to predict in fp16. Only available when `use_faster` is True.",
    )
    parser.add_argument("--profile", action="store_true", help="Whether to profile generation.")
    args = parser.parse_args()
    return args


def profile(batch_size, total_step=50, warmup_step=10, rank=0):
    """Decorator that times a function over `total_step` calls and prints
    latency/QPS, excluding the first `warmup_step` calls from the timing."""

    def _wrapper(func):
        def _impl(*args, **kwargs):
            for i in range(total_step):
                if i == warmup_step:
                    # Start timing only after the warmup calls have finished.
                    paddle.device.cuda.synchronize()
                    start_time = time.time()
                out = func(*args, **kwargs)
            paddle.device.cuda.synchronize()
            end_time = time.time()
            # Only the given rank reports (or every rank when `rank` is None).
            if rank is None or get_ft_para_conf().rank == rank:
                time_interval = end_time - start_time
                num_batch = total_step - warmup_step
                print("Latency: %.2fs, QPS: %.2f" % (time_interval / num_batch, num_batch * batch_size / time_interval))
            return out

        return _impl

    return _wrapper
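

# Illustrative use of the decorator above (this is the same pattern `main`
# applies to `UnifiedTransformerLMHeadModel.generate` when --profile is set):
#
#     timed_generate = profile(batch_size=8)(model.generate)
#     timed_generate(input_ids=...)  # prints latency and QPS on rank 0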


def postprocess_response(token_ids, tokenizer):
    """Post-process the decoded sequence: truncate at the first sep (eos) token."""
    eos_pos = len(token_ids)
    for i, tok_id in enumerate(token_ids):
        if tok_id == tokenizer.sep_token_id:
            eos_pos = i
            break
    token_ids = token_ids[:eos_pos]
    tokens = tokenizer.convert_ids_to_tokens(token_ids)
    tokens = tokenizer.merge_subword(tokens)
    response = " ".join(tokens)
    return response
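

# For example (illustrative token ids, assuming tokenizer.sep_token_id == 2):
#
#     postprocess_response([23, 45, 2, 67], tokenizer)
#
# converts only [23, 45] back to text, since everything from the first
# sep/eos token onward is dropped.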


def main(args):
    # For memory saving when using FastGeneration: if the environment variable
    # `PPFG_QKV_MEM_OPT` is set and the q/k/v weights are fused, FastGeneration
    # will try to delete the original unfused weights. Note that falling back
    # to the original model is no longer guaranteed if the faster model fails
    # after the original weights have been deleted.
    os.environ["PPFG_QKV_MEM_OPT"] = "1"
    if args.use_fp16:
        assert args.use_faster, "FP16 is only supported when using FastGeneration."
        paddle.set_default_dtype("float16")
    enable_ft_para()
    # TODO(guosheng): Maybe device can be set in `enable_ft_para`
    paddle.set_device("gpu:" + str(get_ft_para_conf().rank))

    if args.profile:
        # Wrap `generate` with the timing decorator so every call is profiled.
        UnifiedTransformerLMHeadModel.generate = profile(args.batch_size)(UnifiedTransformerLMHeadModel.generate)
    tokenizer = UnifiedTransformerTokenizer.from_pretrained("plato-xl")
    model = UnifiedTransformerLMHeadModel.from_pretrained("plato-xl")
    model.eval()

    history = [
        "hi , Mary ! What do you usually like to do in your spare time ?",
        "well , I spend a lot of time watching movies .",
        "what a coincidence ! I always watch a lot of movies , too .",
        "oh really , Frank ? What kind of movies do you like ?",
    ]
    inputs = [history] * args.batch_size
    inputs = list(
        map(
            lambda history: tokenizer.dialogue_encode(
                history=history,
                add_start_token_as_response=True,
                return_length=True,
                return_role_ids=args.use_role,
                position_style=args.position_style,
            ),
            inputs,
        )
    )
    collator = DataCollatorWithPadding(tokenizer)
    data = collator(inputs)
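    # `data` is now a dict of padded Tensors (input_ids, token_type_ids,
    # position_ids, attention_mask, seq_len, plus role_ids when --use_role
    # is on), batched over the `batch_size` copies of `history`.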

    outputs, _ = model.generate(
        input_ids=data["input_ids"],
        token_type_ids=data["token_type_ids"],
        position_ids=data["position_ids"],
        attention_mask=data["attention_mask"].cast("float32"),  # TODO(guosheng): remove this cast
        role_ids=data.get("role_ids", None),
        seq_len=data["seq_len"],
        max_length=args.max_out_len,
        min_length=args.min_out_len,
        decode_strategy="sampling",
        top_k=args.topk,
        top_p=args.topp,
        temperature=args.temperature,
        num_return_sequences=args.num_return_sequences,
        use_fast=args.use_faster,
        use_fp16_decoding=args.use_fp16,
    )
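
    # `outputs` holds batch_size * num_return_sequences generated sequences,
    # one row per returned sample; the discarded second return value holds
    # the corresponding sequence scores.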

    # Only let the first (rank 0) process print the results.
    if get_ft_para_conf().rank == 0:
        for i in range(len(outputs)):
            result = postprocess_response(outputs[i].numpy(), tokenizer)
            print("Result:", result)


if __name__ == "__main__":
    args = parse_args()
    pprint(args)
    main(args)
