# Owner(s): ["module: complex"]

import torch
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    dtypes,
)
from torch.testing._internal.common_utils import TestCase, run_tests, set_default_dtype
from torch.testing._internal.common_dtype import complex_types

# Module-level tuple holding the CPU device followed by the first CUDA device.
devices = tuple(torch.device(spec) for spec in ('cpu', 'cuda:0'))

class TestComplexTensor(TestCase):
    """Tests for tensors with complex dtypes.

    Every test takes ``device`` and ``dtype`` arguments; ``@dtypes`` lists the
    dtypes each test is run with. Helper methods start with ``_`` rather than
    ``test`` so they are not treated as tests themselves. The helpers are plain
    instance methods (not ``@staticmethod``) so they keep working when copied
    onto generated per-device classes.
    """

    @dtypes(*complex_types())
    def test_to_list(self, device, dtype):
        # test that the complex float tensor has expected values and
        # there's no garbage value in the resultant list
        self.assertEqual(torch.zeros((2, 2), device=device, dtype=dtype).tolist(), [[0j, 0j], [0j, 0j]])

    @dtypes(torch.float32, torch.float64)
    def test_dtype_inference(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/36834
        # A list containing a Python complex literal must infer the complex dtype
        # that matches the current default floating dtype.
        with set_default_dtype(dtype):
            x = torch.tensor([3., 3. + 5.j], device=device)
        self.assertEqual(x.dtype, torch.cdouble if dtype == torch.float64 else torch.cfloat)

    @dtypes(*complex_types())
    def test_conj_copy(self, device, dtype):
        # issue: https://github.com/pytorch/pytorch/issues/106051
        # Copy a conjugated view of x1 back into x1 itself. torch.conj returns a
        # view, so source and destination alias the same storage.
        x1 = torch.tensor([5 + 1j, 2 + 2j], device=device, dtype=dtype)
        xc1 = torch.conj(x1)
        x1.copy_(xc1)
        self.assertEqual(x1, torch.tensor([5 - 1j, 2 - 2j], device=device, dtype=dtype))

    def _scalar_comparison_pairs(self, device, dtype):
        """Return single-element ``(a, b)`` pairs for the non-vectorized checks.

        ``b`` differs from ``a`` in both parts, only the real part, and only the
        imaginary part respectively, so the elements of ``a`` and ``b`` are never equal.
        """
        a = torch.tensor([-0.0610 - 2.1172j], device=device, dtype=dtype)
        return [
            (a, torch.tensor([z], device=device, dtype=dtype))
            for z in (-6.1278 - 8.5019j, -6.1278 - 2.1172j, -0.0610 - 8.5019j)
        ]

    def _vector_comparison_pair(self, device, dtype):
        """Return one ``(a, b)`` pair of 12-element tensors for the vectorized checks.

        The data consists of two rows of six elements each. In each row:
        - positions 0, 1, 3 and 4 differ between ``a`` and ``b``;
        - position 2 holds the same NaN-containing value in ``a`` and ``b``
          (NaN imaginary part in the first row, NaN real part in the second);
        - position 5 is identical in ``a`` and ``b``.

        NOTE(review): the length is presumably chosen to reach the SIMD kernel
        ("Vectorized operations"); not verified here.
        """
        nan = float("nan")
        a = torch.tensor([
            -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(2.8871, nan), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j,
            -0.0610 - 2.1172j, 5.1576 + 5.4775j, complex(nan, -3.2650), -6.6545 - 3.7655j, -2.7036 - 1.4470j, 0.3712 + 7.989j],
            device=device, dtype=dtype)
        b = torch.tensor([
            -6.1278 - 8.5019j, 0.5886 + 8.8816j, complex(2.8871, nan), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j,
            -6.1278 - 2.1172j, 5.1576 + 8.8816j, complex(nan, -3.2650), 6.3505 + 2.2683j, 0.3712 + 7.9659j, 0.3712 + 7.989j],
            device=device, dtype=dtype)
        return a, b

    def _assert_comparison(self, op, name, a, b, expected):
        """Check ``op(a, b)`` as a plain call and with a complex ``out=`` tensor.

        Args:
            op: comparison function under test (``torch.eq`` or ``torch.ne`` here).
            name: label used in failure messages.
            a, b: complex input tensors of the same shape.
            expected: bool tensor holding the expected element-wise result.
        """
        actual = op(a, b)
        self.assertEqual(actual, expected, msg=f"\n{name}\nactual {actual}\nexpected {expected}")

        # Pre-fill the output with 2+2j, which is neither 0 nor 1, so any element
        # the kernel fails to write shows up as a mismatch. The boolean result is
        # expected to appear in the complex output as 0 or 1.
        actual = torch.full_like(b, complex(2, 2))
        op(a, b, out=actual)
        expected = expected.to(actual.dtype)
        self.assertEqual(actual, expected, msg=f"\n{name}(out)\nactual {actual}\nexpected {expected}")

    @dtypes(*complex_types())
    def test_eq(self, device, dtype):
        "Test eq on complex types"
        # Non-vectorized operations
        for a, b in self._scalar_comparison_pairs(device, dtype):
            self._assert_comparison(torch.eq, "eq", a, b,
                                    torch.tensor([False], device=device, dtype=torch.bool))
            self._assert_comparison(torch.eq, "eq", a, a,
                                    torch.tensor([True], device=device, dtype=torch.bool))

        # Vectorized operations
        a, b = self._vector_comparison_pair(device, dtype)
        # Only the last element of each row matches; NaN never compares equal.
        self._assert_comparison(torch.eq, "eq", a, b,
                                torch.tensor([False, False, False, False, False, True,
                                              False, False, False, False, False, True],
                                             device=device, dtype=torch.bool))
        # a compared with itself is equal everywhere except at the NaN entries.
        self._assert_comparison(torch.eq, "eq", a, a,
                                torch.tensor([True, True, False, True, True, True,
                                              True, True, False, True, True, True],
                                             device=device, dtype=torch.bool))

    @dtypes(*complex_types())
    def test_ne(self, device, dtype):
        "Test ne on complex types"
        # Non-vectorized operations
        for a, b in self._scalar_comparison_pairs(device, dtype):
            self._assert_comparison(torch.ne, "ne", a, b,
                                    torch.tensor([True], device=device, dtype=torch.bool))
            self._assert_comparison(torch.ne, "ne", a, a,
                                    torch.tensor([False], device=device, dtype=torch.bool))

        # Vectorized operations
        a, b = self._vector_comparison_pair(device, dtype)
        # Every element differs except the last of each row; NaN is always "not equal".
        self._assert_comparison(torch.ne, "ne", a, b,
                                torch.tensor([True, True, True, True, True, False,
                                              True, True, True, True, True, False],
                                             device=device, dtype=torch.bool))
        # a compared with itself differs only at the NaN entries.
        self._assert_comparison(torch.ne, "ne", a, a,
                                torch.tensor([False, False, True, False, False, False,
                                              False, False, True, False, False, False],
                                             device=device, dtype=torch.bool))

# Generate device-specific variants of TestComplexTensor (each supplies the
# ``device`` argument the tests take) into this module's global namespace.
instantiate_device_type_tests(
    TestComplexTensor,
    scope=globals(),
)
176
if __name__ == '__main__':
177
TestCase._default_dtype_check_enabled = True