# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
20
# Names exported by ``from <module> import *``: the two test-case base classes
# defined in this file.
__all__ = ["CommonTest", "CpuCommonTest"]
def get_container_type(container):
    """Return the element type of a (possibly nested) list or tuple.

    All elements are assumed to share one data type, so only the first
    element is inspected at each nesting level.

    Args:
        container: Any object. Lists and tuples are descended into; every
            other object is treated as a leaf.

    Returns:
        The type of the innermost first element. For an empty list or tuple
        there is no element to inspect, so the container type itself
        (``list`` or ``tuple``) is returned. For a non-container the
        object's own type is returned.
    """
    container_t = type(container)
    # Only descend when there is a first element; indexing an empty
    # container would raise IndexError.
    if container_t in (list, tuple) and len(container) > 0:
        return get_container_type(container[0])
    return container_t
33
class CommonTest(unittest.TestCase):
34
def __init__(self, methodName="runTest"):
35
super(CommonTest, self).__init__(methodName=methodName)
38
if paddle.is_compiled_with_cuda():
39
self.places.append("gpu")
44
Set the decorators for all test function
46
for key, value in cls.__dict__.items():
47
if key.startswith("test"):
48
decorator_func_list = ["_test_places", "_catch_warnings"]
49
for decorator_func in decorator_func_list:
50
decorator_func = getattr(CommonTest, decorator_func)
51
value = decorator_func(value)
52
setattr(cls, key, value)
54
def _catch_warnings(func):
56
Catch the warnings and treat them as errors for each test.
59
def wrapper(self, *args, **kwargs):
60
with warnings.catch_warnings(record=True) as w:
61
warnings.resetwarnings()
62
# ignore specified warnings
63
warning_white_list = [UserWarning]
64
for warning in warning_white_list:
65
warnings.simplefilter("ignore", warning)
66
func(self, *args, **kwargs)
67
msg = None if len(w) == 0 else w[0].message
68
self.assertFalse(len(w) > 0, msg)
72
def _test_places(func):
74
Setting the running place for each test.
77
def wrapper(self, *args, **kwargs):
80
paddle.set_device(place)
81
func(self, *args, **kwargs)
85
def _check_output_impl(self, result, expected_result, rtol, atol, equal=True):
86
assertForNormalType = self.assertNotEqual
87
assertForFloat = self.assertFalse
89
assertForNormalType = self.assertEqual
90
assertForFloat = self.assertTrue
92
result_t = type(result)
93
error_msg = "Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}"
94
if result_t in [list, tuple]:
95
result_t = get_container_type(result)
96
if result_t in [str, int, bool, set, bool, np.int32, np.int64]:
100
msg=error_msg.format(paddle.get_device(), expected_result, result, self.__class__.__name__),
102
elif result_t in [float, np.ndarray, np.float32, np.float64]:
104
np.allclose(result, expected_result, rtol=rtol, atol=atol),
105
msg=error_msg.format(paddle.get_device(), expected_result, result, self.__class__.__name__),
107
if result_t == np.ndarray:
110
expected_result.shape,
111
msg=error_msg.format(
112
paddle.get_device(), expected_result.shape, result.shape, self.__class__.__name__
117
"result type must be str, int, bool, set, np.bool, np.int32, "
118
"np.int64, np.str, float, np.ndarray, np.float32, np.float64"
121
def check_output_equal(self, result, expected_result, rtol=1.0e-5, atol=1.0e-8):
123
Check whether result and expected result are equal, including shape.
125
result: str, int, bool, set, np.ndarray.
126
The result needs to be checked.
127
expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
128
Use the expected result to check result.
130
relative tolerance, default 1.e-5.
132
absolute tolerance, default 1.e-8
134
self._check_output_impl(result, expected_result, rtol, atol)
136
def check_output_not_equal(self, result, expected_result, rtol=1.0e-5, atol=1.0e-8):
138
Check whether result and expected result are not equal, including shape.
140
result: str, int, bool, set, np.ndarray.
141
The result needs to be checked.
142
expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
143
Use the expected result to check result.
145
relative tolerance, default 1.e-5.
147
absolute tolerance, default 1.e-8
149
self._check_output_impl(result, expected_result, rtol, atol, equal=False)
class CpuCommonTest(CommonTest):
    """Variant of ``CommonTest`` whose place list is restricted to the CPU."""

    def __init__(self, methodName="runTest"):
        super().__init__(methodName=methodName)
        # Replace whatever place list the parent constructor built.
        self.places = ["cpu"]