paddlenlp

Форк
0
/
common_test.py 
155 строк · 5.8 Кб
1
# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
2
#
3
# Licensed under the Apache License, Version 2.0 (the "License");
4
# you may not use this file except in compliance with the License.
5
# You may obtain a copy of the License at
6
#
7
#     http://www.apache.org/licenses/LICENSE-2.0
8
#
9
# Unless required by applicable law or agreed to in writing, software
10
# distributed under the License is distributed on an "AS IS" BASIS,
11
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
# See the License for the specific language governing permissions and
13
# limitations under the License.
14
import functools
import inspect
import unittest
15
import warnings
16

17
import numpy as np
18
import paddle
19

20
__all__ = ["CommonTest", "CpuCommonTest"]
21

22

23
# Assume all elements has same data type
24
def get_container_type(container):
25
    container_t = type(container)
26
    if container_t in [list, tuple]:
27
        if len(container) == 0:
28
            return container_t
29
        return get_container_type(container[0])
30
    return container_t
31

32

33
class CommonTest(unittest.TestCase):
34
    def __init__(self, methodName="runTest"):
35
        super(CommonTest, self).__init__(methodName=methodName)
36
        self.config = {}
37
        self.places = ["cpu"]
38
        if paddle.is_compiled_with_cuda():
39
            self.places.append("gpu")
40

41
    @classmethod
42
    def setUpClass(cls):
43
        """
44
        Set the decorators for all test function
45
        """
46
        for key, value in cls.__dict__.items():
47
            if key.startswith("test"):
48
                decorator_func_list = ["_test_places", "_catch_warnings"]
49
                for decorator_func in decorator_func_list:
50
                    decorator_func = getattr(CommonTest, decorator_func)
51
                    value = decorator_func(value)
52
                setattr(cls, key, value)
53

54
    def _catch_warnings(func):
55
        """
56
        Catch the warnings and treat them as errors for each test.
57
        """
58

59
        def wrapper(self, *args, **kwargs):
60
            with warnings.catch_warnings(record=True) as w:
61
                warnings.resetwarnings()
62
                # ignore specified warnings
63
                warning_white_list = [UserWarning]
64
                for warning in warning_white_list:
65
                    warnings.simplefilter("ignore", warning)
66
                func(self, *args, **kwargs)
67
                msg = None if len(w) == 0 else w[0].message
68
                self.assertFalse(len(w) > 0, msg)
69

70
        return wrapper
71

72
    def _test_places(func):
73
        """
74
        Setting the running place for each test.
75
        """
76

77
        def wrapper(self, *args, **kwargs):
78
            places = self.places
79
            for place in places:
80
                paddle.set_device(place)
81
                func(self, *args, **kwargs)
82

83
        return wrapper
84

85
    def _check_output_impl(self, result, expected_result, rtol, atol, equal=True):
86
        assertForNormalType = self.assertNotEqual
87
        assertForFloat = self.assertFalse
88
        if equal:
89
            assertForNormalType = self.assertEqual
90
            assertForFloat = self.assertTrue
91

92
        result_t = type(result)
93
        error_msg = "Output has diff at place:{}. \nExpect: {} \nBut Got: {} in class {}"
94
        if result_t in [list, tuple]:
95
            result_t = get_container_type(result)
96
        if result_t in [str, int, bool, set, bool, np.int32, np.int64]:
97
            assertForNormalType(
98
                result,
99
                expected_result,
100
                msg=error_msg.format(paddle.get_device(), expected_result, result, self.__class__.__name__),
101
            )
102
        elif result_t in [float, np.ndarray, np.float32, np.float64]:
103
            assertForFloat(
104
                np.allclose(result, expected_result, rtol=rtol, atol=atol),
105
                msg=error_msg.format(paddle.get_device(), expected_result, result, self.__class__.__name__),
106
            )
107
            if result_t == np.ndarray:
108
                assertForNormalType(
109
                    result.shape,
110
                    expected_result.shape,
111
                    msg=error_msg.format(
112
                        paddle.get_device(), expected_result.shape, result.shape, self.__class__.__name__
113
                    ),
114
                )
115
        else:
116
            raise ValueError(
117
                "result type must be str, int, bool, set, np.bool, np.int32, "
118
                "np.int64, np.str, float, np.ndarray, np.float32, np.float64"
119
            )
120

121
    def check_output_equal(self, result, expected_result, rtol=1.0e-5, atol=1.0e-8):
122
        """
123
            Check whether result and expected result are equal, including shape.
124
        Args:
125
            result: str, int, bool, set, np.ndarray.
126
                The result needs to be checked.
127
            expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
128
                Use the expected result to check result.
129
            rtol: float
130
                relative tolerance, default 1.e-5.
131
            atol: float
132
                absolute tolerance, default 1.e-8
133
        """
134
        self._check_output_impl(result, expected_result, rtol, atol)
135

136
    def check_output_not_equal(self, result, expected_result, rtol=1.0e-5, atol=1.0e-8):
137
        """
138
            Check whether result and expected result are not equal, including shape.
139
        Args:
140
            result: str, int, bool, set, np.ndarray.
141
                The result needs to be checked.
142
            expected_result: str, int, bool, set, np.ndarray. The type has to be same as result's.
143
                Use the expected result to check result.
144
            rtol: float
145
                relative tolerance, default 1.e-5.
146
            atol: float
147
                absolute tolerance, default 1.e-8
148
        """
149
        self._check_output_impl(result, expected_result, rtol, atol, equal=False)
150

151

152
class CpuCommonTest(CommonTest):
    """``CommonTest`` variant whose tests run on the CPU place only."""

    def __init__(self, methodName="runTest"):
        super().__init__(methodName=methodName)
        # Replace the places chosen by the parent (which may include "gpu").
        self.places = ["cpu"]
156

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.