paddlenlp

Форк
0
162 строки · 5.6 Кб
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
# Copyright 2018 The Google AI Language Team Authors and The HuggingFace Inc. team.
3
#
4
# Licensed under the Apache License, Version 2.0 (the "License");
5
# you may not use this file except in compliance with the License.
6
# You may obtain a copy of the License at
7
#
8
#     http://www.apache.org/licenses/LICENSE-2.0
9
#
10
# Unless required by applicable law or agreed to in writing, software
11
# distributed under the License is distributed on an "AS IS" BASIS,
12
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
# See the License for the specific language governing permissions and
14
# limitations under the License.
15

16
import argparse
17

18
# Prepend the project root dir to sys.path so the script can run with the latest local code
19
import sys
20
import time
21
from pprint import pprint
22

23
import numpy as np
24
import paddle
25
import torch
26
from transformers.models.opt.modeling_opt import OPTForCausalLM as hf_opt_model
27

28
from paddlenlp.transformers import GPTTokenizer, OPTForCausalLM
29

30
# Put the directory two levels up (relative to the current working directory)
# at the front of the module search path.
# NOTE(review): this statement runs after the `paddlenlp` import above, so it
# cannot influence which `paddlenlp` package that import resolved to. Confirm
# whether it was meant to precede the import.
sys.path.insert(0, "../../")
31

32

33
def parse_args(argv=None) -> argparse.Namespace:
    """Parse the command-line options of the OPT generation benchmark.

    Args:
        argv: Optional list of argument strings to parse. When ``None``
            (the default), argparse falls back to ``sys.argv[1:]``, which is
            what the script entry point relies on.

    Returns:
        The populated ``argparse.Namespace`` with the attributes
        ``model_name_or_path``, ``decode_strategy``, ``top_k``, ``batch_size``,
        ``top_p``, ``max_length`` and ``use_fp16_decoding``.
    """
    # Single source of truth for the accepted checkpoints; reused in the help
    # text so the two can never drift apart.
    model_choices = ["facebook/opt-125m", "facebook/opt-350m", "facebook/opt-1.3b", "facebook/opt-2.7b"]
    strategy_choices = ["greedy_search", "sampling"]

    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--model_name_or_path",
        default="facebook/opt-125m",
        type=str,
        choices=model_choices,
        help="The name of the OPT model to use. Can be one of {}. ".format(model_choices),
    )
    parser.add_argument(
        "--decode_strategy",
        default="greedy_search",
        type=str,
        choices=strategy_choices,
        help="The decoding strategy. Can be one of {}".format(strategy_choices),
    )
    parser.add_argument(
        "--top_k",
        default=4,
        type=int,
        help="The number of highest-probability candidate tokens kept for top-k sampling. ",
    )
    parser.add_argument("--batch_size", default=4, type=int, help="The size of input batch. ")
    parser.add_argument(
        "--top_p", default=1.0, type=float, help="The cumulative probability threshold for top-p sampling. "
    )
    parser.add_argument("--max_length", default=32, type=int, help="Maximum output length. ")
    parser.add_argument("--use_fp16_decoding", action="store_true", help="Whether to use fp16 decoding to predict. ")
    return parser.parse_args(argv)
58

59

60
def _average_latency_ms(run_once, synchronize, num_loop=100, num_warmup=50):
    """Return the mean wall-clock time of ``run_once`` in milliseconds.

    ``run_once`` is called ``num_loop`` times in total. The first
    ``num_warmup`` calls are not timed. ``synchronize`` is invoked right
    before the timer starts and right before it stops, so the measured
    interval is bracketed by device synchronization.

    Args:
        run_once: Zero-argument callable performing one generation pass.
        synchronize: Zero-argument callable that waits for the device.
        num_loop: Total number of calls (warm-up plus timed).
        num_warmup: Number of leading calls excluded from the measurement.

    Returns:
        Average duration of one timed call, in milliseconds.

    Raises:
        ValueError: If the arguments leave no timed iteration.
    """
    num_timed = num_loop - num_warmup
    if num_warmup < 0 or num_timed <= 0:
        raise ValueError("num_loop must be greater than num_warmup, and num_warmup must be non-negative")

    for _ in range(num_warmup):
        run_once()

    synchronize()
    start = time.perf_counter()
    for _ in range(num_timed):
        run_once()
    synchronize()
    # The divisor is derived from the loop counts so that changing them can
    # never silently skew the average.
    return (time.perf_counter() - start) / num_timed * 1000


def do_predict(args):
    """Benchmark OPT text generation and print the average latencies.

    Three back ends are timed with the same single-token prompt batch:
    PaddleNLP ``generate`` with ``use_fast=True``, PaddleNLP ``generate``
    with default settings, and HuggingFace ``OPTForCausalLM.generate``.
    When ``args.use_fp16_decoding`` is set, only the fast PaddleNLP path is
    measured and the function returns right after printing its cost.

    Args:
        args: Namespace providing ``model_name_or_path``, ``batch_size``,
            ``max_length``, ``decode_strategy``, ``top_k``, ``top_p`` and
            ``use_fp16_decoding``.
    """
    # PaddlePaddle >= 2.2 (required for paddle.device.cuda.synchronize).
    place = paddle.set_device("gpu")

    tokenizer = GPTTokenizer.from_pretrained(args.model_name_or_path)
    model = OPTForCausalLM.from_pretrained(args.model_name_or_path)
    # Set evaluate mode
    model.eval()

    # The same special token serves as both begin- and end-of-sequence id.
    bos_id = tokenizer.convert_tokens_to_ids("<|endoftext|>")
    eos_id = bos_id

    # Every sample in the batch starts from a single BOS token: shape [batch_size, 1].
    input_ids_np = np.full([args.batch_size, 1], bos_id, dtype="int64")
    input_ids = paddle.to_tensor(input_ids_np)

    # Arguments shared by both PaddleNLP runs.
    paddle_kwargs = dict(
        input_ids=input_ids,
        max_length=args.max_length,
        decode_strategy=args.decode_strategy,
        top_k=args.top_k,
        top_p=args.top_p,
        bos_token_id=bos_id,
        eos_token_id=eos_id,
    )

    def paddle_sync():
        paddle.device.cuda.synchronize(place)

    with paddle.no_grad():
        fast_cost = _average_latency_ms(
            lambda: model.generate(use_fast=True, use_fp16_decoding=args.use_fp16_decoding, **paddle_kwargs),
            paddle_sync,
        )

    if args.use_fp16_decoding:
        # Only the fast path is measured in fp16 mode; skip the comparisons.
        pprint(args)
        print("Fast FP16 cost:", fast_cost)
        return

    with paddle.no_grad():
        pd_cost = _average_latency_ms(lambda: model.generate(**paddle_kwargs), paddle_sync)

    device = torch.device("cuda:0")
    hf_model = hf_opt_model.from_pretrained(args.model_name_or_path)
    hf_model.to(device)
    hf_model.eval()

    hf_input_ids = torch.tensor(input_ids_np).to(device)
    do_sample = args.decode_strategy == "sampling"

    def hf_generate():
        hf_model.generate(
            hf_input_ids,
            do_sample=do_sample,
            # NOTE(review): presumably +1 because HF counts the prompt token in
            # max_length while the PaddleNLP call does not; confirm.
            max_length=args.max_length + 1,
            bos_token_id=bos_id,
            eos_token_id=eos_id,
            pad_token_id=0,
            top_k=args.top_k,
            top_p=args.top_p,
        )

    with torch.no_grad():
        hf_cost = _average_latency_ms(hf_generate, torch.cuda.synchronize)

    pprint(args)
    print("Fast FP32 cost:", fast_cost)
    print("PD cost:", pd_cost)
    print("HF cost:", hf_cost)
    print("Speed up Fast FP32/PD:", pd_cost / fast_cost)
    print("Speed up Fast FP32/HF:", hf_cost / fast_cost)
157

158

159
if __name__ == "__main__":
    # Script entry point: read the CLI options, echo the chosen model, then
    # run the benchmark.
    cli_args = parse_args()
    print(cli_args.model_name_or_path)
    do_predict(cli_args)
163

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.