pytorch

Форк
0
/
test_prims.py 
459 строк · 16.8 Кб
1
# Owner(s): ["module: decompositions"]
2

3
from functools import partial
4
from itertools import product
5
import unittest
6

7
import torch
8
from torch.testing import make_tensor
9
from torch.testing._internal.common_utils import (parametrize, run_tests, TestCase, TEST_SCIPY,
10
                                                  set_default_dtype)
11
from torch.testing._internal.common_device_type import (
12
    instantiate_device_type_tests,
13
    onlyCUDA,
14
    dtypes,
15
    OpDTypes,
16
)
17
from torch.testing._internal.common_methods_invocations import (
18
    op_db,
19
)
20
from torch.testing._internal.common_device_type import (
21
    ops,
22
)
23

24
from torch.testing._internal.logging_tensor import LoggingTensor, capture_logs, log_input
25
import torch._prims as prims
26
from torch._prims_common import CUDARngStateHelper
27
from torch._prims.executor import make_traced
28
import torch._refs as refs
29

30

31
if TEST_SCIPY:
32
    import scipy.special
33

34
# Presumably a fragment of a warning message to match against (TODO confirm);
# it is not referenced anywhere in the visible part of this file.
NVPRIM_ATEN_FALLBACK_WARNING = "fallback to aten executor"
35
# Presumably a fragment of an error message to match against (TODO confirm);
# it is not referenced anywhere in the visible part of this file.
GET_ISOLATED_GRAPHMODULE_ERROR = "get_isolated_graphmodule failed on decomposition"
36

37
class TestPrims(TestCase):
38
    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim(self, device, dtype):
        # Runs prims.broadcast_in_dim through the traced 'aten' executor.
        # In every case `b` only supplies the target shape; the result is
        # compared with the equivalent eager operation applied to `a`.
        def _wrapper(a, b, broadcast_dimensions):
            return prims.broadcast_in_dim(a, b.shape, broadcast_dimensions)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)

            # Same shape
            shape = (5, 5)
            a = make_arg(shape)
            b = make_arg(shape, low=0.0, high=0.0)
            result = fn(a, b, (0, 1))

            self.assertEqual(result.shape, a.shape)
            # is_contiguous is a method and has to be called: asserting on
            # the bound method object is always truthy and verifies nothing.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(a, result)

            # Error input: reordering dims
            with self.assertRaises(Exception):
                result = fn(a, b, (1, 0))

            # Adding outermost dimensions
            a = make_arg((5, 5))
            b = make_arg((3, 3, 5, 5), low=0.0, high=0.0)
            result = fn(a, b, (2, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.broadcast_to(b.shape), result)

            # Expands
            a = make_arg((1, 5, 1))
            b = make_arg((3, 5, 7), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 2))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.expand_as(result), result)

            # Unsqueezes
            a = make_arg((1, 2, 3))
            b = make_arg((1, 2, 1, 3), low=0.0, high=0.0)
            result = fn(a, b, (0, 1, 3))

            self.assertEqual(result.shape, b.shape)
            self.assertEqual(a.unsqueeze(2), result)
86

87
    @onlyCUDA
    @dtypes(torch.float32)
    def test_broadcast_in_dim_sum(self, device, dtype):
        # Sums a 2-D tensor over both dims, broadcasts the scalar result to
        # the empty shape, and checks that the traced 'aten' executor agrees
        # with calling the wrapper directly.
        def _wrapper(a):
            a_sum = prims.sum(a, [0, 1])
            a_bc = prims.broadcast_in_dim(a_sum, [], [])
            return a_bc

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            # Call is_contiguous(); the bare bound method is always truthy,
            # so the previous assertion could never fail.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)
107

108
    @unittest.skipIf(not TEST_SCIPY, "SciPy not found")
    @dtypes(torch.float64, torch.long)
    def test_cbrt_prim(self, device, dtype):
        # Compares prims.cbrt with scipy.special.cbrt for every combination
        # of a leading batch shape and a trailing shape.
        batch_dims = [(), (1,), (2,), (0, 1), (1, 1), (2, 2)]
        trailing_dims = [(), (0,), (1,), (5,)]

        # Sets the default dtype to NumPy's default dtype of double
        with set_default_dtype(torch.double):
            # Tested here, as this OP is not currently exposed or tested in ATen
            for batch, trailing in product(batch_dims, trailing_dims):
                sample = make_tensor(batch + trailing, device=device, dtype=dtype)
                actual = prims.cbrt(sample)
                expected = scipy.special.cbrt(sample.cpu().numpy())

                self.assertEqual(actual, expected, exact_device=False)
126

127
    @dtypes(torch.float32)
    def test_collapse(self, device, dtype):
        # Checks prims.collapse (copy) and prims.collapse_view (view) against
        # reshape for several dim ranges, then the non-contiguous and
        # invalid-dims cases.

        # Build the input on the device/dtype under test; it was previously
        # always a default CPU tensor, so every device variant re-ran the
        # CPU case.
        t = torch.rand(2, 2, 2, device=device, dtype=dtype)
        dim_ranges = [(0, 0), (0, 1), (1, 2), (0, 2)]
        expected_shapes = [(2, 2, 2), (4, 2), (2, 4), (8,)]

        for (start, end), shape in zip(dim_ranges, expected_shapes):
            expect = t.reshape(shape)

            copy = prims.collapse(t, start, end)
            self.assertEqual(copy, expect)
            self.assertFalse(copy._is_view())

            view = prims.collapse_view(t, start, end)
            self.assertEqual(view, expect)
            self.assertTrue(view._is_view())

        t_discontig = t.transpose(0, 1)
        # NOTE: `msg` is only the text reported if no ValueError is raised;
        # it is not matched against the exception message.
        with self.assertRaises(ValueError, msg="no such view exists"):
            view = prims.collapse_view(t_discontig, 0, 2)

        # The copying variant is still expected to work on the transposed input.
        copy = prims.collapse(t_discontig, 0, 1)
        self.assertEqual(copy, t_discontig.reshape(4, 2))

        error_dims = [(-1, 1), (0, 3), (1, -1)]
        for start, end in error_dims:
            for fn in [prims.collapse, prims.collapse_view]:
                with self.assertRaises(AssertionError):
                    fn(t, start, end)
156

157

158
    def test_aten_overload_to_prims(self, device):
        # This test is to ensure that the torch.ops.aten calls are replaced with refs
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        def func(a):
            inner = torch.ops.aten.digamma.default(a)
            return torch.ops.aten.sigmoid.default(inner)

        sample = torch.randn(3, 3, device=device)

        with TorchRefsMode():
            traced_module = make_fx(func)(sample)

        # Check that all call_function nodes are prims
        call_nodes = [
            node for node in traced_module.graph.nodes if node.op == "call_function"
        ]
        self.assertTrue(
            all(node.target.name().startswith("prims") for node in call_nodes)
        )
177

178
    @onlyCUDA
    @dtypes(torch.float32)
    @parametrize("correction", [0, 1])
    def test_var(self, device, dtype, correction):
        # Reduces a 2-D tensor over both dims with prims.var and checks that
        # the traced 'aten' executor agrees with calling the wrapper directly.
        def _wrapper(a):
            return prims.var(a, [0, 1], correction=correction)

        traced = make_traced(_wrapper)
        make_arg = partial(make_tensor, device=device, dtype=dtype)

        for executor in ('aten',):
            fn = partial(traced, executor=executor)
            shape = (5, 5)
            a = make_arg(shape)
            result = fn(a)

            self.assertEqual(result.shape, ())
            # Call is_contiguous(); the bare bound method is always truthy,
            # so the previous assertion could never fail.
            self.assertTrue(result.is_contiguous())
            self.assertEqual(_wrapper(a), result)
197

198
    @dtypes(torch.float32)
    def test_memory_format_strides(self, device, dtype):
        # For each (shape group, memory format) pair, checks that refs.empty,
        # refs.clone and refs.contiguous produce the same strides as their
        # eager torch counterparts.
        generic_shapes = (
            (),
            (0,),
            (1,),
            (5,),  # was `(5)`, which is the int 5 rather than a 1-tuple
            (1, 0),
            (1, 1),
            (3, 7),
            (3, 0, 2),
            (1, 1, 2),
            (4, 1, 1),
            (7, 8, 9),
        )

        channels_last_shapes = (
            (0, 0, 0, 0),
            (1, 0, 3, 0),
            (0, 2, 3, 5),
            (2, 2, 2, 0),
            (5, 4, 3, 2),
            (8, 8, 7, 2),
            (9, 1, 3, 1),
            (4, 5, 8, 7),
        )

        channels_last_3d_shapes = (
            (0, 8, 7, 9, 2),
            (5, 0, 7, 9, 2),
            (5, 0, 7, 9, 0),
            (5, 8, 7, 9, 2),
            (5, 1, 7, 9, 2),
            (5, 1, 7, 9, 1),
        )

        pairs = (
            (generic_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.contiguous_format),
            (channels_last_3d_shapes, torch.contiguous_format),
            (channels_last_shapes, torch.channels_last),
            (channels_last_3d_shapes, torch.channels_last_3d),
        )

        # The loop variable gets its own name so it no longer shadows the
        # tuple of generic shapes defined above.
        for shape_group, memory_format in pairs:
            for shape in shape_group:
                # tests empty
                expected = torch.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                actual = refs.empty(shape, device=device, dtype=dtype, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests clone
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype)
                expected = torch.clone(a, memory_format=memory_format)
                # Use the reference implementation here: the check used to
                # compare torch.clone with itself and could never fail.
                actual = refs.clone(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())

                # tests contiguous
                a = torch.testing.make_tensor(shape, device=device, dtype=dtype, noncontiguous=True)
                expected = a.contiguous(memory_format=memory_format)
                actual = refs.contiguous(a, memory_format=memory_format)
                self.assertEqual(expected.stride(), actual.stride())
260

261
    @dtypes(torch.float32)
    def test_reshape_view_method(self, device, dtype):
        # The vararg forms of refs.reshape and refs.view must match the
        # corresponding tensor methods.
        source = make_tensor((5, 5), device=device, dtype=dtype)
        target_shape = (1, 5, 1, 5)

        # reshape is checked first, then view, as before.
        for op_name in ("reshape", "view"):
            eager_out = getattr(source, op_name)(*target_shape)
            refs_out = getattr(refs, op_name)(source, *target_shape)
            self.assertEqual(eager_out, refs_out)
273

274

275
    @onlyCUDA
    @dtypes(torch.float32)
    def test_philox_rand(self, device, dtype):
        # torch.rand outputs are recorded together with the RNG state read
        # just before each call; philox_rand fed with those (seed, offset)
        # pairs must reproduce them.
        repeats = 2  # Checks multiple rand calls results with multiple philox_rand calls
        for size in (1000, 1000000):  # offsets of 4 and 8
            torch.cuda.manual_seed(123)
            recorded = []
            for _ in range(repeats):
                # The state has to be captured before the rand call that consumes it.
                state = CUDARngStateHelper.get_torch_state_as_tuple()
                reference = torch.rand(size, device=device, dtype=dtype)
                recorded.append((state, reference))

            torch.cuda.manual_seed(123)
            replayed = []
            for (seed, offset), _reference in recorded:
                output, _unused = torch.ops.rngprims.philox_rand(
                    (size,),
                    seed=seed,
                    offset=offset,
                    stride=None,
                    device=device,
                    dtype=dtype,
                )
                replayed.append(output)

            for (_state, reference), output in zip(recorded, replayed):
                self.assertEqual(reference, output)
302

303

304
    @dtypes(torch.float32)
    def test_functional_rng_wrappers(self, device, dtype):
        # run_and_save_rng_state must yield the same numbers as plain
        # torch.rand under the same seed, and run_with_rng_state must
        # reproduce them from the saved states.
        tensor_kwargs = dict(device=device, dtype=dtype)

        torch.manual_seed(123)
        ref1 = torch.rand(10, **tensor_kwargs)
        ref2 = torch.rand(10, **tensor_kwargs)

        torch.manual_seed(123)
        rng_prims = torch._prims.rng_prims
        state1, saved1 = rng_prims.run_and_save_rng_state(torch.rand, 10, **tensor_kwargs)
        state2, saved2 = rng_prims.run_and_save_rng_state(torch.rand, 10, **tensor_kwargs)

        replay1 = rng_prims.run_with_rng_state(state1, torch.rand, 10, **tensor_kwargs)
        replay2 = rng_prims.run_with_rng_state(state2, torch.rand, 10, **tensor_kwargs)

        comparisons = ((ref1, saved1), (ref2, saved2), (ref1, replay1), (ref2, replay2))
        for expected, actual in comparisons:
            self.assertEqual(expected, actual)
323

324
class TestPrimsBasic(TestCase):
    # Tests that are not instantiated per device in this file.

    def test_torch_ops(self):
        # prims.sin reached through torch.ops must agree with torch.sin.
        plain = make_tensor((2,), device='cpu', dtype=torch.float)
        self.assertEqual(torch.ops.prims.sin(plain), torch.sin(plain))

        # Called on a LoggingTensor, prims.sin must show up in the captured log.
        logged = LoggingTensor(plain)
        with capture_logs() as logs:
            log_input("input", logged)
            prims.sin(logged)
        self.assertExpectedInline('\n'.join(logs), """\
$0: f32[2] = input('input')
$1: f32[2] = torch._ops.prims.sin.default($0)""")

    def test_mul_complex(self):
        # Smoke test: nothing is asserted, the call only has to not raise.
        real_input = torch.randn(2)
        prims.mul(real_input, 1 + 1j)

    def test_check_deprecation_warning(self):
        # torch._prims_common.check must emit a DeprecationWarning whose
        # text matches the pattern below.
        expected_pattern = 'will be removed in the future'
        with self.assertWarnsRegex(DeprecationWarning, expected_pattern):
            torch._prims_common.check(True, lambda: 'message')
343

344

345
# Passes TestPrims and this module's globals to the device-type instantiation
# helper (presumably it adds the per-device test classes to this module;
# confirm against torch.testing._internal.common_device_type).
instantiate_device_type_tests(TestPrims, globals())
346

347

348
class TestRefs(TestCase):
    # Checks around the reference implementations in torch._refs.

    @dtypes(torch.float32)
    def test_constant_pad_nd_memory_format(self, device, dtype):
        # Test memory format is preserved in unambiguous cases
        unambiguous_cases = (
            (torch.channels_last, 4),
            (torch.contiguous_format, 4),
            (torch.channels_last_3d, 5),
            (torch.contiguous_format, 5),
        )
        for mf, ndim in unambiguous_cases:
            source = torch.zeros([2] * ndim).to(memory_format=mf)
            padded = refs.constant_pad_nd(source, pad=[1] * (2 * ndim))
            self.assertTrue(padded.is_contiguous(memory_format=mf))

        # Ambiguous cases: both inputs report as channels_last-contiguous and
        # as contiguous; only the strides differ.
        ambiguous_cases = (
            # is_channels_last_ and is_contiguous_, results in channels_last output
            ((4, 1, 2, 1), torch.channels_last),
            # is_channels_last_contiguous_ but not is_channels_last_, results in
            # contiguous output
            ((4, 4, 2, 1), torch.contiguous_format),
        )
        for stride, expected_format in ambiguous_cases:
            source = torch.empty_strided((2, 1, 2, 2), stride=stride)
            self.assertTrue(source.is_contiguous(memory_format=torch.channels_last))
            self.assertTrue(source.is_contiguous())
            actual = refs.constant_pad_nd(source, pad=[1] * 8)
            expect = torch.constant_pad_nd(source, pad=[1] * 8)
            self.assertEqual(actual.stride(), expect.stride())
            self.assertTrue(actual.is_contiguous(memory_format=expected_format))

    def test_unbind(self):
        # If unbind returns empty tuple, it breaks some assumptions in some backward tests in test_ops.py.
        # So can't put this test into common_methods_invocations.py.
        empty_along_dim = torch.rand([3, 0, 4])
        self.assertEqual(refs.unbind(empty_along_dim, 1), torch.unbind(empty_along_dim, 1))

    def test_logspace_with_complex_input(self):
        # refs.logspace must match torch.logspace for a complex end point.
        endpoints = (2, 10 + 5j)
        actual = refs.logspace(*endpoints, steps=5)
        expect = torch.logspace(*endpoints, steps=5)
        self.assertEqual(actual, expect)

    def test_linspace_with_complex_input(self):
        # refs.linspace must match torch.linspace for a complex end point.
        endpoints = (2, 10 + 5j)
        actual = refs.linspace(*endpoints, steps=5)
        expect = torch.linspace(*endpoints, steps=5)
        self.assertEqual(actual, expect)

    # From https://github.com/pytorch/pytorch/issues/109558
    def test_infinite_loop_from_py_dispatcher(self):
        # Nothing is asserted: the conversion only has to complete.
        # enables prim decomps
        with torch._dispatch.python.enable_python_dispatcher():
            torch.ones(4).to(device="meta")
407

408

409
# Passes TestRefs and this module's globals to the device-type instantiation
# helper (presumably it adds the per-device test classes to this module;
# confirm against torch.testing._internal.common_device_type).
instantiate_device_type_tests(TestRefs, globals())
410

411

412
class TestDecomp(TestCase):
    @ops([op for op in op_db if op.supports_varargs], dtypes=OpDTypes.any_one)
    def test_decomposition_method_vararg(self, device, dtype, op):
        # some ops have vararg variants for the methods. this tests it.
        # we don't have tests for varargs in OpInfo, so we need to
        # improvise this a bit.
        # The rule for general functions (the special cases being e.g. tensor
        # creation functions taking shapes) is that things can be vararg
        # if the method has only one argument of sequence type.
        # e.g. permute can be called on a 3d tensor t as t.permute(0, 2, 1)
        #      as well as t.permute([0, 2, 1])
        #      when the signature in native_functions.yaml
        #      shows arguments Tensor self, IntList dims
        # we might need to adjust things for the factory functions or
        # have them do their own test
        from torch.fx.experimental.proxy_tensor import make_fx
        from torch._prims.context import TorchRefsMode

        # just run one test, we assume there is a suitable one in the tests;
        # samples whose would-be varargs are an empty tuple are skipped,
        # as that cannot be the varargs
        sample_input = next(
            si
            for si in op.sample_inputs(device, dtype, requires_grad=False)
            if (si.args[-1] if si.args else si.input)
        )
        all_args = (sample_input.input,) + sample_input.args
        # The last argument is spread out so the call uses the vararg form.
        call_args = (*all_args[:-1], *all_args[-1])

        # in general, the methods take varargs and not (always?) the function
        # variants, the exception to this rule are the factory functions
        fn = op.op if op.is_factory_function else op.method_variant

        with TorchRefsMode():
            gm = make_fx(fn)(*call_args)

        # in case we add random factory functions
        torch.manual_seed(1)
        res = gm(*call_args)
        torch.manual_seed(1)
        expected = fn(*call_args)
        self.assertEqual(res, expected)
453

454

455
# Passes TestDecomp and this module's globals to the device-type instantiation
# helper (presumably it adds the per-device test classes to this module;
# confirm against torch.testing._internal.common_device_type).
instantiate_device_type_tests(TestDecomp, globals())
456

457

458
if __name__ == "__main__":
    # Executed as a script: hand control to the imported run_tests helper.
    run_tests()
460

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.