paddlenlp

Форк
0
64 строки · 1.8 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from __future__ import absolute_import, division, print_function
16

17
import argparse
18
import os
19
import sys
20

21
import paddle.distributed.fleet as fleet
22

23
__dir__ = os.path.dirname(os.path.abspath(__file__))
24
sys.path.append(os.path.abspath(os.path.join(__dir__, "../", "../")))
25

26
from ppfleetx.core.engine.inference_engine import InferenceEngine
27
from ppfleetx.data import tokenizers
28

29

30
def parse_args():
31
    parser = argparse.ArgumentParser()
32
    parser.add_argument("--mp_degree", default=1, type=int, help="")
33
    parser.add_argument("--model_dir", default="output", type=str, help="model directory")
34

35
    args = parser.parse_args()
36
    return args
37

38

39
def main():
40

41
    args = parse_args()
42

43
    fleet.init(is_collective=True)
44
    infer_engine = InferenceEngine(args.model_dir, args.mp_degree)
45

46
    tokenizer = tokenizers.GPTTokenizer.from_pretrained("gpt2")
47
    input_text = "Hi, GPT2. Tell me where is Beijing?"
48
    ids = [tokenizer.encode(input_text)]
49

50
    # run test
51

52
    outs = infer_engine.predict([ids])
53

54
    ids = list(outs.values())[0]
55
    out_ids = [int(x) for x in ids[0]]
56
    result = tokenizer.decode(out_ids)
57
    result = input_text + result
58

59
    print("Prompt:", input_text)
60
    print("Generation:", result)
61

62

63
if __name__ == "__main__":
64
    main()
65

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.