paddlenlp

Форк
0
/
test_ernie_vil.py 
67 строк · 2.3 Кб
1
# Copyright (c) 2023 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14

15
from __future__ import annotations
16

17
import os
18
import sys
19
from unittest import TestCase
20

21
from paddlenlp.utils import install_package
22
from paddlenlp.utils.downloader import get_path_from_url
23
from tests.testing_utils import argv_context_guard, load_test_config
24

25

26
class ErnieViLTest(TestCase):
27
    def setUp(self) -> None:
28
        self.path = "./model_zoo/ernie-vil2.0"
29
        self.config_path = "./tests/fixtures/model_zoo/ernie_vil.yaml"
30
        sys.path.insert(0, self.path)
31

32
    def tearDown(self) -> None:
33
        sys.path.remove(self.path)
34

35
    def test_finetune(self):
36
        install_package("lmdb", "1.3.0")
37
        if not os.path.exists("./tests/fixtures/Flickr30k-CN"):
38
            URL = "https://paddlenlp.bj.bcebos.com/tests/Flickr30k-CN-small.zip"
39
            get_path_from_url(URL, root_dir="./tests/fixtures")
40

41
        # 1. run finetune
42
        finetune_config = load_test_config(self.config_path, "finetune")
43
        with argv_context_guard(finetune_config):
44
            from run_finetune import do_train
45

46
            do_train()
47

48
        # 2. export model
49
        export_config = {
50
            "model_path": finetune_config["output_dir"],
51
            "output_path": finetune_config["output_dir"],
52
        }
53
        with argv_context_guard(export_config):
54
            from export_model import main
55

56
            main()
57

58
        # 3. infer model
59
        infer_config = {
60
            "image_path": "./tests/fixtures/tests_samples/COCO/000000039769.png",
61
            "model_dir": export_config["output_path"],
62
            "device": finetune_config["device"],
63
        }
64
        with argv_context_guard(infer_config):
65
            from deploy.python.infer import main
66

67
            main()
68

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.