from functools import partial
from itertools import product
import unittest

import torch
from torch.testing import make_tensor
from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
                                                  set_default_dtype)
from torch.testing._internal.common_device_type import (
    instantiate_device_type_tests,
    onlyCUDA,
    ops,
    dtypes,
    OpDTypes,
)
from torch.testing._internal.common_methods_invocations import (
    op_db,
)
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
import torch._prims as prims
from torch._prims_common import CUDARngStateHelper
from torch._prims.executor import make_traced
import torch._refs as refs

if TEST_SCIPY:
    import scipy.special

NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"


class TestPrims(TestCase):
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        def _wrapper(a, b, broadcast_dimensions):
            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)

            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims is not allowed
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)
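
    # Note: prims.broadcast_in_dim(a, shape, broadcast_dimensions) maps
    # dimension i of `a` onto dimension broadcast_dimensions[i] of the output;
    # output dimensions not listed are inserted as new size-one dims. E.g. a
    # (1, 2, 3) input with broadcast_dimensions (0, 1, 3) produces a
    # (1, 2, 1, 3) result, equivalent to a.unsqueeze(2) as asserted above.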

    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        batches = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        shapes = [(), (0,), (1,), (5,)]

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            # Tested here, as cbrt is not currently exposed or tested in ATen
            for b, s in product(batches, shapes):
                x = make_arg(b + s)
                y = prims.cbrt(x)

                x_np = x.cpu().numpy()
                y_np = scipy.special.cbrt(x_np)

                self.assertEqual(y, y_np, exact_device=False)

    @dtypes(torch.float32)
    def test_collapse(self, device, dtype):
        t = torch.rand(2, 2, 2)
        dim_ranges = [(0, 0), (0, 1), (1, 2), (0, 2)]
        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]

        for (start, end), shape in zip(dim_ranges, expected_shapes):
            expect = t.reshape(shape)

            copy = prims.collapse(t, start, end)
            self.assertEqual(copy, expect)
            self.assertFalse(copy._is_view())

            view = prims.collapse_view(t, start, end)
            self.assertEqual(view, expect)
            self.assertTrue(view._is_view())

        t_discontig = t.transpose(0, 1)
        with self.assertRaises(ValueError, msg="no such view exists"):
            view = prims.collapse_view(t_discontig, 0, 2)

        copy = prims.collapse(t_discontig, 0, 1)
        self.assertEqual(copy, t_discontig.reshape(4, 2))

        error_dims = [(-1, 1), (0, 3), (1, -1)]
        for start, end in error_dims:
            for fn in [prims.collapse, prims.collapse_view]:
                with self.assertRaises(AssertionError):
                    fn(t, start, end)
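
    # Note: prims.collapse flattens the inclusive dim range [start, end] into a
    # single dimension and always returns a copy, while prims.collapse_view
    # returns a view and raises ValueError when no such view exists (e.g. for
    # the transposed, discontiguous tensor above). Out-of-range or negative
    # dims are rejected with an AssertionError by both variants.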

    def test_aten_overload_to_prims(self, device):
        # Ensure torch.ops.aten calls are rewritten into the prims namespace
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        a = torch.randn(3, 3, device=device)

        def func(a):
            return torch.ops.aten.sigmoid.default(torch.ops.aten.digamma.default(a))

        with TorchRefsMode():
            gm = make_fx(func)(a)

        # Check that all call_function nodes are prims
        call_function_nodes = list(filter(lambda n: n.op == "call_function", gm.graph.nodes))
        all_prims_namespace = all(
            node.target.name().startswith("prims") for node in call_function_nodes
        )
        self.assertTrue(all_prims_namespace)
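
    # Note: TorchRefsMode is what reroutes the aten overloads to their
    # torch._refs implementations while tracing with make_fx, which is why the
    # resulting graph contains only prims.* call_function nodes.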

    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)

    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        # Assorted shapes: 0-d through 3-d, including zero-element and
        # size-one dims
        shapes = (
            (),
            (0,),
            (1,),
            (5,),
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        # 4D shapes eligible for torch.channels_last
        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7),
        )

        # 5D shapes eligible for torch.channels_last_3d
        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        for shapes, memory_format in pairs:
            for shape in shapes:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())
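
    # Note: for a 4D (N, C, H, W) tensor, torch.channels_last corresponds to an
    # NHWC layout with strides (H*W*C, 1, W*C, C); the checks above require the
    # refs to reproduce exactly the strides the eager implementations choose.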

    @dtypes(torch.float32)
    def test_reshape_view_method(self, device, dtype):
        make_arg = partial(make_tensor, device=device, dtype=dtype)
        a = make_arg((5, 5))
        new_shape = 1, 5, 1, 5
        result_eager = a.reshape(*new_shape)
        result_refs = refs.reshape(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

        result_eager = a.view(*new_shape)
        result_refs = refs.view(a, *new_shape)
        self.assertEqual(result_eager, result_refs)

    @onlyCUDA
    @dtypes(torch.float32)
    def test_philox_rand(self, device, dtype):
        sizes = (1000, 1000000)
        repeats = 2  # compare repeated torch.rand calls against philox_rand
        for size in sizes:
            torch.cuda.manual_seed(123)
            references = []
            results = []
            rng_states = []
            for _ in range(repeats):
                rng_states.append(CUDARngStateHelper.get_torch_state_as_tuple())
                references.append(torch.rand(size, device=device, dtype=dtype))

            torch.cuda.manual_seed(123)
            for idx in range(repeats):
                seed, offset = rng_states[idx]
                result, _ = torch.ops.rngprims.philox_rand((size,),
                                                           seed=seed,
                                                           offset=offset,
                                                           stride=None,
                                                           device=device,
                                                           dtype=dtype)
                results.append(result)

            for a, b in zip(references, results):
                self.assertEqual(a, b)
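
    # Note: CUDARngStateHelper.get_torch_state_as_tuple() captures the CUDA
    # generator state as a (seed, offset) pair, so replaying each saved state
    # through the functional philox_rand prim reproduces torch.rand exactly.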

    @dtypes(torch.float32)
    def test_functional_rng_wrappers(self, device, dtype):
        # Generate reference values with plain torch.rand
        torch.manual_seed(123)
        ref1 = torch.rand(10, device=device, dtype=dtype)
        ref2 = torch.rand(10, device=device, dtype=dtype)

        # Run through the functional wrappers and replay the saved states
        torch.manual_seed(123)
        rng_state1, res1 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)
        rng_state2, res2 = torch._prims.rng_prims.run_and_save_rng_state(torch.rand, 10, device=device, dtype=dtype)

        res3 = torch._prims.rng_prims.run_with_rng_state(rng_state1, torch.rand, 10, device=device, dtype=dtype)
        res4 = torch._prims.rng_prims.run_with_rng_state(rng_state2, torch.rand, 10, device=device, dtype=dtype)

        self.assertEqual(ref1, res1)
        self.assertEqual(ref2, res2)
        self.assertEqual(ref1, res3)
        self.assertEqual(ref2, res4)
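
    # Note: run_and_save_rng_state executes the callable and also returns the
    # RNG state it consumed; run_with_rng_state replays such a saved state, so
    # the same callable can be re-run later with identical results.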


class TestPrimsBasic(TestCase):
    def test_torch_ops(self):
        r = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(r), torch.sin(r))

        r = LoggingTensor(r)
        with capture_logs() as logs:
            log_input("input", r)
            prims.sin(r)
        self.assertExpectedInline('\n'.join(logs), """\
$0: f32[2] = input('input')
$1: f32[2] = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        prims.mul(torch.randn(2), 1 + 1j)

    def test_check_deprecation_warning(self):
        with self.assertWarnsRegex(DeprecationWarning, 'will be removed in the future'):
            torch._prims_common.check(True, lambda: 'message')
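
    # Note: the warning is expected even though the checked condition is True;
    # torch._prims_common.check emits the DeprecationWarning on every call, not
    # only on failure.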

instantiate_device_type_tests(TestPrims, globals())


class TestRefs(TestCase):
    @dtypes(torch.float32)
    def test_constant_pad_nd_memory_format(self, device, dtype):
        # Test that memory_format is preserved in unambiguous cases
        for mf, ndim in (
                (torch.channels_last, 4),
                (torch.contiguous_format, 4),
                (torch.channels_last_3d, 5),
                (torch.contiguous_format, 5),
        ):
            a = torch.zeros([2] * ndim).to(memory_format=mf)
            res = refs.constant_pad_nd(a, pad=[1] * (2 * ndim))
            self.assertTrue(res.is_contiguous(memory_format=mf))

        # Ambiguous case: strides (4, 1, 2, 1) make the tensor both
        # channels-last contiguous and contiguous; channels_last should win
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 1, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous(memory_format=torch.channels_last))

        # Ambiguous case: strides (4, 4, 2, 1) are likewise both; here the
        # contiguous format should win
        a = torch.empty_strided((2, 1, 2, 2), stride=(4, 4, 2, 1))
        self.assertTrue(a.is_contiguous(memory_format=torch.channels_last))
        self.assertTrue(a.is_contiguous())
        actual = refs.constant_pad_nd(a, pad=[1] * 8)
        expect = torch.constant_pad_nd(a, pad=[1] * 8)
        self.assertEqual(actual.stride(), expect.stride())
        self.assertTrue(actual.is_contiguous())

    def test_unbind(self):
        # unbind along a zero-size dim returns an empty tuple
        a = torch.rand([3, 0, 4])
        actual = refs.unbind(a, 1)
        expect = torch.unbind(a, 1)
        self.assertEqual(actual, expect)

    def test_logspace_with_complex_input(self):
        actual = refs.logspace(2, 10 + 5j, steps=5)
        expect = torch.logspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    def test_linspace_with_complex_input(self):
        actual = refs.linspace(2, 10 + 5j, steps=5)
        expect = torch.linspace(2, 10 + 5j, steps=5)
        self.assertEqual(actual, expect)

    def test_infinite_loop_from_py_dispatcher(self):
        # The python dispatcher enables prim decompositions; .to() must not
        # recurse infinitely under it
        with torch._dispatch.python.enable_python_dispatcher():
            x = torch.ones(4)
            y = x.to(device="meta")

instantiate_device_type_tests(TestRefs, globals())


class TestDecomp(TestCase):
    @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
    def test_decomposition_method_vararg(self, device, dtype, op):
        # Some ops have vararg variants of their methods (e.g. permute,
        # reshape, view); OpInfo has no vararg samples, so improvise by
        # splatting the final tuple/list argument.
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        # skip samples whose would-be varargs (the last arg, or the input for
        # factory functions) is empty
        sample_inputs = (si for si in op.sample_inputs(device, dtype, requires_grad=False)
                         if (si.args[-1] if si.args else si.input))

        # just run one test, assuming there is a suitable sample input
        sample_input = next(sample_inputs)
        all_args = (sample_input.input,) + sample_input.args

        # in general the method variants take varargs; factory functions are
        # the exception and are called through their function variant
        if op.is_factory_function:
            fn = op.op
        else:
            fn = op.method_variant
        with TorchRefsMode():
            gm = make_fx(fn)(*all_args[:-1], *all_args[-1])

        res = gm(*all_args[:-1], *all_args[-1])
        expected = fn(*all_args[:-1], *all_args[-1])
        self.assertEqual(res, expected)

instantiate_device_type_tests(TestDecomp, globals())


if __name__ == "__main__":
    run_tests()