# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations

import json
import os
import shutil
import sys
import tempfile

import paddle

from tests.testing_utils import argv_context_guard, load_test_config

class LLMTest:
    # Mixin for end-to-end predictor tests; concrete test classes are expected
    # to also inherit from unittest.TestCase, which supplies test discovery
    # and assertEqual.
    config_path: str = None
    data_dir = "./tests/fixtures/llm/data/"

    def setUp(self) -> None:
        self.root_path = "./llm"
        self.output_dir = tempfile.mkdtemp()
        self.inference_output_dir = tempfile.mkdtemp()
        sys.path.insert(0, self.root_path)
        paddle.set_default_dtype("float32")

    def tearDown(self) -> None:
        sys.path.remove(self.root_path)
        shutil.rmtree(self.output_dir)
        shutil.rmtree(self.inference_output_dir)
        paddle.device.cuda.empty_cache()

    def disable_static(self):
        # return paddle to dynamic-graph (imperative) mode and reset the
        # unique-name scope between stages
        paddle.utils.unique_name.switch()
        paddle.disable_static()

    def _read_result(self, file):
        # read the "output" field from each line of a jsonl file,
        # e.g. {"output": "..."}
        result = []
        with open(file, "r", encoding="utf-8") as f:
            for line in f:
                data = json.loads(line)
                result.append(data["output"])
        return result

    def run_predictor(self, config_params=None):
        if config_params is None:
            config_params = {}

        # reset graph mode and the unique-name scope so that rebuilding the
        # model does not reuse the same parameter names
        self.disable_static()

        # dynamic-graph prediction
        predict_config = load_test_config(self.config_path, "inference-predict")
        predict_config["output_file"] = os.path.join(self.output_dir, "predict.json")
        predict_config["model_name_or_path"] = self.output_dir
        predict_config.update(config_params)

        with argv_context_guard(predict_config):
            from predictor import predict

            predict()

        # prefix_tuning dynamic graph does not support to_static, so stop here
        if not predict_config["inference_model"]:
            return

        # export the model to a static graph for inference
        config = load_test_config(self.config_path, "inference-to-static")
        config["output_path"] = self.inference_output_dir
        config["model_name_or_path"] = self.output_dir
        config.update(config_params)
        with argv_context_guard(config):
            from export_model import main

            main()

        # run inference against the exported static-graph model
        config = load_test_config(self.config_path, "inference-infer")
        config["model_name_or_path"] = self.inference_output_dir
        config["output_file"] = os.path.join(self.inference_output_dir, "infer.json")

        # the exported model lives in inference_output_dir, so drop any
        # model_name_or_path override passed by the caller
        config_params.pop("model_name_or_path", None)
        config.update(config_params)
        with argv_context_guard(config):
            from predictor import predict

            predict()

        self.disable_static()

        # dynamic-graph and static-graph outputs must agree line by line
        predict_result = self._read_result(predict_config["output_file"])
        infer_result = self._read_result(config["output_file"])
        assert len(predict_result) == len(infer_result)

        for predict_item, infer_item in zip(predict_result, infer_result):
            self.assertEqual(predict_item, infer_item)
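

# A minimal usage sketch, not part of the original file: a concrete test would
# mix LLMTest with unittest.TestCase and point config_path at a real fixture.
# The class name and yaml path below are hypothetical placeholders.
#
# class ExampleInferenceTest(LLMTest, unittest.TestCase):
#     config_path = "./tests/fixtures/llm/example.yaml"  # hypothetical fixture
#
#     def test_predictor(self):
#         # a trained model must already sit in self.output_dir; run_predictor
#         # then compares dynamic-graph prediction with static-graph inference
#         self.run_predictor({"inference_model": True})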