paddlenlp

Форк
0
/
testing_utils.py 
110 строк · 3.6 Кб
1
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
14
from __future__ import annotations
15

16
import json
17
import os
18
import shutil
19
import sys
20
import tempfile
21

22
import paddle
23

24
from tests.testing_utils import argv_context_guard, load_test_config
25

26

27
class LLMTest:
    """Mixin driving an end-to-end LLM inference test pipeline.

    The pipeline is: dynamic-graph prediction, export to a static graph,
    static-graph inference, then an assertion that both paths produce
    identical outputs.  Intended to be mixed into a ``unittest.TestCase``
    subclass — it relies on ``self.assertEqual`` and the setUp/tearDown
    protocol and does not inherit from ``TestCase`` itself.
    """

    # Path to the test-config file consumed by ``load_test_config``;
    # subclasses must override this.
    config_path: str = None
    # Location of shared fixture data for LLM tests.
    data_dir = "./tests/fixtures/llm/data/"

    def setUp(self) -> None:
        """Create scratch output dirs and put the ``llm`` scripts on ``sys.path``."""
        self.root_path = "./llm"
        self.output_dir = tempfile.mkdtemp()
        self.inference_output_dir = tempfile.mkdtemp()
        # The predictor/export scripts live under ./llm and are imported lazily.
        sys.path.insert(0, self.root_path)
        self.disable_static()
        paddle.set_default_dtype("float32")

    def tearDown(self) -> None:
        """Remove scratch dirs, restore ``sys.path`` and release GPU memory."""
        sys.path.remove(self.root_path)
        shutil.rmtree(self.output_dir)
        shutil.rmtree(self.inference_output_dir)
        self.disable_static()
        paddle.device.cuda.empty_cache()

    def disable_static(self):
        """Reset unique-name scopes and switch paddle back to dynamic graph."""
        paddle.utils.unique_name.switch()
        paddle.disable_static()

    def _read_result(self, file):
        """Collect the ``output`` field from each line of a JSON-lines file.

        Args:
            file: path to a JSON-lines file where every line is an object
                containing an ``output`` key.

        Returns:
            list: the ``output`` value of every line, in file order.
        """
        # Each line is an independent JSON object (JSON-lines format).
        with open(file, "r", encoding="utf-8") as f:
            return [json.loads(line)["output"] for line in f]

    def run_predictor(self, config_params=None):
        """Run dynamic prediction, then (if enabled) export + static inference,
        and verify both produce identical outputs.

        Args:
            config_params: optional dict of overrides merged into every
                stage's config.  The caller's dict is never mutated.
        """
        # Work on a copy: we pop keys below, and mutating the caller's dict
        # would leak state between invocations (the original implementation
        # modified the argument in place).
        config_params = dict(config_params or {})

        # Stage 1: dynamic-graph prediction.
        self.disable_static()
        predict_config = load_test_config(self.config_path, "inference-predict")
        predict_config["output_file"] = os.path.join(self.output_dir, "predict.json")
        predict_config["model_name_or_path"] = self.output_dir
        predict_config.update(config_params)

        with argv_context_guard(predict_config):
            from predictor import predict

            predict()

        # prefix_tuning dynamic graph do not support to_static;
        # .get() lets a config without the key skip the static path instead
        # of raising KeyError.
        if not predict_config.get("inference_model"):
            return

        # Stage 2: export the model to a static graph.
        self.disable_static()
        config = load_test_config(self.config_path, "inference-to-static")
        config["output_path"] = self.inference_output_dir
        config["model_name_or_path"] = self.output_dir
        config.update(config_params)
        with argv_context_guard(config):
            from export_model import main

            main()

        # Stage 3: static-graph inference on the exported model.
        self.disable_static()
        config = load_test_config(self.config_path, "inference-infer")
        config["model_name_or_path"] = self.inference_output_dir
        config["output_file"] = os.path.join(self.inference_output_dir, "infer.json")

        # The exported model dir must win over any caller-supplied path.
        config_params.pop("model_name_or_path", None)
        config.update(config_params)
        with argv_context_guard(config):
            from predictor import predict

            predict()

        self.disable_static()

        # Dynamic and static outputs must match line-for-line.
        predict_result = self._read_result(predict_config["output_file"])
        infer_result = self._read_result(config["output_file"])
        assert len(predict_result) == len(infer_result)

        for predict_item, infer_item in zip(predict_result, infer_result):
            self.assertEqual(predict_item, infer_item)
111

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.