pytorch

Форк
0
/
test_complex.py 
178 строк · 9.4 Кб
1
# Owner(s): ["module: complex"]
2

3
import torch
4
from torch.testing._internal.common_device_type import (
5
    instantiate_device_type_tests,
6
    dtypes,
7
    onlyCPU,
8
)
9
from torch.testing._internal.common_utils import TestCase, run_tests, set_default_dtype
10
from torch.testing._internal.common_dtype import complex_types
11

12
# Device handles for the CPU and the first CUDA device.
# NOTE(review): not referenced anywhere in the visible part of this file.
devices = tuple(torch.device(name) for name in ("cpu", "cuda:0"))
13

14
class TestComplexTensor(TestCase):
15
    @dtypes(*complex_types())
16
    def test_to_list(self, device, dtype):
17
        # test that the complex float tensor has expected values and
18
        # there's no garbage value in the resultant list
19
        self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])
20

21
    @dtypes(torch.float32, torch.float64)
22
    def test_dtype_inference(self, device, dtype):
23
        # issue: https://github.com/pytorch/pytorch/issues/36834
24
        with set_default_dtype(dtype):
25
            x = torch.tensor([3., 3. + 5.j], device=device)
26
        self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)
27

28
    @dtypes(*complex_types())
29
    def test_conj_copy(self, device, dtype):
30
        # issue: https://github.com/pytorch/pytorch/issues/106051
31
        x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype)
32
        xc1 = torch.conj(x1)
33
        x1.copy_(xc1)
34
        self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))
35

36
    @onlyCPU
37
    @dtypes(*complex_types())
38
    def test_eq(self, device, dtype):
39
        "Test eq on complex types"
40
        nan = float("nan")
41
        # Non-vectorized operations
42
        for a, b in (
43
            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
44
             torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype)),
45
            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
46
             torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype)),
47
            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
48
             torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype)),
49
        ):
50
            actual = torch.eq(a, b)
51
            expected = torch.tensor([False], device=device, dtype=torch.bool)
52
            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
53

54
            actual = torch.eq(a, a)
55
            expected = torch.tensor([True], device=device, dtype=torch.bool)
56
            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
57

58
            actual = torch.full_like(b, complex(2, 2))
59
            torch.eq(a, b, out=actual)
60
            expected = torch.tensor([complex(0)], device=device, dtype=dtype)
61
            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
62

63
            actual = torch.full_like(b, complex(2, 2))
64
            torch.eq(a, a, out=actual)
65
            expected = torch.tensor([complex(1)], device=device, dtype=dtype)
66
            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
67

68
        # Vectorized operations
69
        for a, b in (
70
            (torch.tensor([
71
                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(2.8871, nan), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j,
72
                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(nan, -3.2650), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j],
73
                device=device, dtype=dtype),
74
             torch.tensor([
75
                -6.1278 - 8.5019j, 0.5886 + 8.8816j, complex(2.8871, nan), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j,
76
                -6.1278 - 2.1172j, 5.1576 + 8.8816j, complex(nan, -3.2650), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j],
77
                device=device, dtype=dtype)),
78
        ):
79
            actual = torch.eq(a, b)
80
            expected = torch.tensor([False, False, False, False, False, True,
81
                                    False, False, False, False, False, True],
82
                                    device=device, dtype=torch.bool)
83
            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
84

85
            actual = torch.eq(a, a)
86
            expected = torch.tensor([True, True, False, True, True, True,
87
                                    True, True, False, True, True, True],
88
                                    device=device, dtype=torch.bool)
89
            self.assertEqual(actual, expected, msg=f"\neq\nactual {actual}\nexpected {expected}")
90

91
            actual = torch.full_like(b, complex(2, 2))
92
            torch.eq(a, b, out=actual)
93
            expected = torch.tensor([complex(0), complex(0), complex(0), complex(0), complex(0), complex(1),
94
                                    complex(0), complex(0), complex(0), complex(0), complex(0), complex(1)],
95
                                    device=device, dtype=dtype)
96
            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
97

98
            actual = torch.full_like(b, complex(2, 2))
99
            torch.eq(a, a, out=actual)
100
            expected = torch.tensor([complex(1), complex(1), complex(0), complex(1), complex(1), complex(1),
101
                                    complex(1), complex(1), complex(0), complex(1), complex(1), complex(1)],
102
                                    device=device, dtype=dtype)
103
            self.assertEqual(actual, expected, msg=f"\neq(out)\nactual {actual}\nexpected {expected}")
104

105
    @onlyCPU
106
    @dtypes(*complex_types())
107
    def test_ne(self, device, dtype):
108
        "Test ne on complex types"
109
        nan = float("nan")
110
        # Non-vectorized operations
111
        for a, b in (
112
            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
113
             torch.tensor([-6.1278 - 8.5019j], device=device, dtype=dtype)),
114
            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
115
             torch.tensor([-6.1278 - 2.1172j], device=device, dtype=dtype)),
116
            (torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype),
117
             torch.tensor([-0.0610 - 8.5019j], device=device, dtype=dtype)),
118
        ):
119
            actual = torch.ne(a, b)
120
            expected = torch.tensor([True], device=device, dtype=torch.bool)
121
            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
122

123
            actual = torch.ne(a, a)
124
            expected = torch.tensor([False], device=device, dtype=torch.bool)
125
            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
126

127
            actual = torch.full_like(b, complex(2, 2))
128
            torch.ne(a, b, out=actual)
129
            expected = torch.tensor([complex(1)], device=device, dtype=dtype)
130
            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
131

132
            actual = torch.full_like(b, complex(2, 2))
133
            torch.ne(a, a, out=actual)
134
            expected = torch.tensor([complex(0)], device=device, dtype=dtype)
135
            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
136

137
        # Vectorized operations
138
        for a, b in (
139
            (torch.tensor([
140
                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(2.8871, nan), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j,
141
                -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(nan, -3.2650), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j],
142
                device=device, dtype=dtype),
143
             torch.tensor([
144
                -6.1278 - 8.5019j, 0.5886 + 8.8816j, complex(2.8871, nan), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j,
145
                -6.1278 - 2.1172j, 5.1576 + 8.8816j, complex(nan, -3.2650), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j],
146
                device=device, dtype=dtype)),
147
        ):
148
            actual = torch.ne(a, b)
149
            expected = torch.tensor([True, True, True, True, True, False,
150
                                    True, True, True, True, True, False],
151
                                    device=device, dtype=torch.bool)
152
            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
153

154
            actual = torch.ne(a, a)
155
            expected = torch.tensor([False, False, True, False, False, False,
156
                                    False, False, True, False, False, False],
157
                                    device=device, dtype=torch.bool)
158
            self.assertEqual(actual, expected, msg=f"\nne\nactual {actual}\nexpected {expected}")
159

160
            actual = torch.full_like(b, complex(2, 2))
161
            torch.ne(a, b, out=actual)
162
            expected = torch.tensor([complex(1), complex(1), complex(1), complex(1), complex(1), complex(0),
163
                                    complex(1), complex(1), complex(1), complex(1), complex(1), complex(0)],
164
                                    device=device, dtype=dtype)
165
            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
166

167
            actual = torch.full_like(b, complex(2, 2))
168
            torch.ne(a, a, out=actual)
169
            expected = torch.tensor([complex(0), complex(0), complex(1), complex(0), complex(0), complex(0),
170
                                    complex(0), complex(0), complex(1), complex(0), complex(0), complex(0)],
171
                                    device=device, dtype=dtype)
172
            self.assertEqual(actual, expected, msg=f"\nne(out)\nactual {actual}\nexpected {expected}")
173

174
# Instantiate the device-specific variants of the generic TestComplexTensor
# template into this module's global namespace.
instantiate_device_type_tests(TestComplexTensor, globals())


if __name__ == "__main__":
    # Switch on the TestCase default-dtype check before starting the runner.
    TestCase._default_dtype_check_enabled = True
    run_tests()
179

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.