paddlenlp

Форк
0
78 строк · 2.5 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
import argparse
16
import os
17
import sys
18
import time
19

20
import numpy as np
21
import paddle.distributed.fleet as fleet
22

23
__dir__ = os.path.dirname(os.path.abspath(__file__))
24
sys.path.append(os.path.abspath(os.path.join(__dir__, "../", "../")))
25

26
from ppfleetx.core.engine.inference_engine import InferenceEngine
27

28

29
def parse_args():
30
    parser = argparse.ArgumentParser()
31
    parser.add_argument("--seq_len", default=128, type=int, required=False, help="seq length of inputs")
32
    parser.add_argument("--iter", default=100, type=int, help="run iterations for timing")
33
    parser.add_argument("--mp_degree", default=1, type=int, help="")
34
    parser.add_argument("--model_dir", default="output", type=str, help="model directory")
35

36
    args = parser.parse_args()
37
    return args
38

39

40
def predict(engine, data, args):
41

42
    with engine._static_guard:
43
        for d, name in zip(data, engine.input_names()):
44
            handle = engine.predictor.get_input_handle(name)
45
            handle.copy_from_cpu(d)
46

47
        for _ in range(10):
48
            engine.predictor.run()
49
        engine.predictor.get_output_handle(engine.output_names()[0]).copy_to_cpu()
50

51
        start = time.perf_counter()
52
        for _ in range(args.iter):
53
            engine.predictor.run()
54
        end = time.perf_counter()
55
        print(f"batch {data.shape} run time: {1000 * (end - start) / args.iter}ms")
56

57
        return {name: engine.predictor.get_output_handle(name).copy_to_cpu() for name in engine.output_names()}
58

59

60
def main():
61

62
    args = parse_args()
63

64
    fleet.init(is_collective=True)
65
    infer_engine = InferenceEngine(args.model_dir, args.mp_degree)
66
    ids = [100] * args.seq_len
67

68
    # run test
69
    for batch in [1, 2, 4, 8, 16]:
70

71
        whole_data = [ids] * batch
72
        whole_data = np.array(whole_data, dtype="int64").reshape(1, batch, -1)
73

74
        _ = predict(infer_engine, whole_data, args)
75

76

77
if __name__ == "__main__":
78
    main()
79

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.