pytorch

Форк
0
/
test_with_effects.py 
907 строк · 34.6 Кб
1
# Owner(s): ["module: functorch"]
2
# flake8: noqa: B950
3
import unittest
4
from collections import deque
5
from functools import partial
6
from typing import List, TYPE_CHECKING
7

8
import torch
9
import torch._dynamo
10
import torch._functorch
11
import torch._inductor
12
import torch._inductor.decomposition
13
from functorch.compile import (
14
    aot_function,
15
    default_decompositions,
16
    min_cut_rematerialization_partition,
17
    nop,
18
)
19
from torch._functorch.aot_autograd import aot_export_module
20
from torch._higher_order_ops.effects import with_effects
21
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
22
from torch.fx.experimental.proxy_tensor import make_fx
23
from torch.testing import FileCheck
24
from torch.testing._internal.common_cuda import (
25
    _get_torch_cuda_version,
26
    SM70OrLater,
27
    SM80OrLater,
28
)
29
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
30
from torch.testing._internal.common_utils import (
31
    IS_WINDOWS,
32
    run_tests,
33
    skipIfTorchDynamo,
34
    TEST_CUDA,
35
    TEST_WITH_ROCM,
36
    TestCase,
37
)
38
from torch.testing._internal.torchbind_impls import init_torchbind_implementations
39

40

41
if TYPE_CHECKING:
42
    from torch.utils.hooks import RemovableHandle
43

44
from torch.testing._internal.two_tensor import TwoTensor
45

46

47
def extract_graph(fx_g, _, graph_cell):
    """AOTAutograd compiler stub that records the graph it is given.

    Stores ``fx_g`` in ``graph_cell[0]`` so the caller can inspect it later, and
    returns ``fx_g`` unchanged so that graph is what actually executes. The second
    positional argument (the example inputs) is ignored.
    """
    captured = fx_g
    graph_cell[0] = captured
    return captured
50

51

52
def get_fw_bw_graph(
    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
    """Compile ``f`` with ``aot_function`` and return the captured FX graphs.

    ``f`` is called once on ``*inps`` through ``aot_function`` using
    ``default_decompositions``. The forward graph is always captured. If any tensor
    in ``inps`` requires grad, ``out.sum().backward()`` is also run and the
    backward graph is captured; otherwise the backward compiler is ``nop``.

    Args:
        f: the function to compile.
        inps: positional inputs for ``f`` (a pytree; only tensor leaves are
            inspected for ``requires_grad``).
        partitioner: joint-graph partition function passed as ``partition_fn``.
        dynamic: forwarded to ``aot_function``.

    Returns:
        ``(fw_graph, bw_graph)``. ``bw_graph`` is ``None`` when no input tensor
        requires grad.
    """
    fw_graph_cell = [None]
    bw_graph_cell = [None]

    # Only tensor leaves carry requires_grad; any other leaves are ignored.
    requires_grad = any(
        leaf.requires_grad
        for leaf in torch.utils._pytree.tree_leaves(inps)
        if isinstance(leaf, torch.Tensor)
    )

    bw_compiler = (
        partial(extract_graph, graph_cell=bw_graph_cell) if requires_grad else nop
    )
    out = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
        bw_compiler=bw_compiler,
        partition_fn=partitioner,
        decompositions=default_decompositions,
        dynamic=dynamic,
    )(*inps)

    if requires_grad:
        # Run backward so the backward graph passes through bw_compiler before the
        # cell is read (presumably it may be compiled lazily — not shown here).
        out.sum().backward()

    return (fw_graph_cell[0], bw_graph_cell[0])
81

82

83
def make_inputs_non_leaves(inps):
    """Return a copy of ``inps`` with every tensor replaced by ``tensor.add(1)``.

    Because each tensor becomes the output of an ``add``, tensors that required grad
    come back as non-leaf autograd nodes. Non-tensor leaves are passed through.
    """

    def _bump(tensor):
        return tensor.add(1)

    return torch.utils._pytree.tree_map_only(torch.Tensor, _bump, inps)
85

86

87
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't support")
88
class TestWithEffects(TestCase):
89
    def setUp(self):
        """Run the base TestCase setup, then initialize the torchbind test helpers.

        ``init_torchbind_implementations`` presumably registers the
        ``_TorchScriptTesting`` classes/ops used by the torchbind tests below.
        """
        # Overriding setUp without chaining skips the base class's per-test setup
        # (NOTE(review): in PyTorch's TestCase this includes e.g. RNG seeding).
        super().setUp()
        init_torchbind_implementations()
91

92
    def test_print(self):
93
        class M(torch.nn.Module):
94
            def forward(self, x):
95
                torch.ops.aten._print("moo")
96
                res = x + x
97
                torch.ops.aten._print("moo")
98
                return (res,)
99

100
        inputs = (torch.randn(3),)
101

102
        # Without functionalization, print should just appear in the graph directly
103
        gm = make_fx(M())(*inputs)
104
        FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
105
            gm.code
106
        )
107

108
        # With functionalization, it should appear wrapped with with_effects()
109
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
110
        self.assertExpectedInline(
111
            str(gm.code).strip(),
112
            """\
113
def forward(self, arg0_1, arg1_1):
114
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
115
    getitem = with_effects[0];  with_effects = None
116
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
117
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
118
    getitem_2 = with_effects_1[0];  with_effects_1 = None
119
    return (getitem_2, add)""",
120
        )
121
        self.assertEqual(len(gs.input_tokens), 1)
122
        self.assertEqual(len(gs.output_tokens), 1)
123

124
        with torch._functorch.config.patch(unlift_effect_tokens=True):
125
            gm, gs = aot_export_module(M(), inputs, trace_joint=False)
126
            self.assertExpectedInline(
127
                str(gm.code).strip(),
128
                """\
129
def forward(self, arg1_1):
130
    _make_token_default = torch.ops.prims._make_token.default()
131
    with_effects = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo');  _make_token_default = None
132
    getitem = with_effects[0];  with_effects = None
133
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1);  arg1_1 = None
134
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
135
    getitem_2 = with_effects_1[0];  with_effects_1 = None
136
    _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]);  getitem_2 = _sink_tokens_default = None
137
    return [add]""",  # noqa: B950
138
            )
139

140
    def test_torchbind_custom_op(self):
141
        class M(torch.nn.Module):
142
            def __init__(self) -> None:
143
                super().__init__()
144
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)
145

146
            def forward(self, x):
147
                return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)
148

149
        with enable_torchbind_tracing():
150
            gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)
151

152
        self.assertExpectedInline(
153
            str(gm.code).strip(),
154
            """\
155
def forward(self, arg0_1, arg1_1):
156
    _torchbind_obj0 = self._torchbind_obj0
157
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, _torchbind_obj0, arg1_1);  arg0_1 = _torchbind_obj0 = None
158
    getitem = with_effects[0]
159
    getitem_1 = with_effects[1];  with_effects = None
160
    add = torch.ops.aten.add.Tensor(arg1_1, getitem_1);  arg1_1 = getitem_1 = None
161
    return (getitem, add)""",  # noqa: B950
162
        )
163
        self.assertEqual(len(gs.input_tokens), 1)
164
        self.assertEqual(len(gs.output_tokens), 1)
165

166
    def test_print_with_buffer_mutations(self):
167
        class M(torch.nn.Module):
168
            def __init__(self) -> None:
169
                super().__init__()
170
                self.buf = torch.nn.Buffer(torch.ones(3))
171

172
            def forward(self, x):
173
                torch.ops.aten._print("moo")
174
                res = x + x
175
                self.buf.add_(res)
176
                res = self.buf + x
177
                torch.ops.aten._print("moo")
178
                return (res,)
179

180
        inputs = (torch.randn(3),)
181

182
        # With functionalization, it should appear wrapped with with_effects()
183
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
184
        self.assertExpectedInline(
185
            str(gm.code).strip(),
186
            """\
187
def forward(self, arg0_1, arg1_1, arg2_1):
188
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo');  arg0_1 = None
189
    getitem = with_effects[0];  with_effects = None
190
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
191
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add);  arg1_1 = add = None
192
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1);  arg2_1 = None
193
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo');  getitem = None
194
    getitem_2 = with_effects_1[0];  with_effects_1 = None
195
    return (getitem_2, add_1, add_2)""",
196
        )
197
        self.assertEqual(len(gs.input_tokens), 1)
198
        self.assertEqual(len(gs.output_tokens), 1)
199
        self.assertEqual(len(gs.buffers_to_mutate), 1)
200

201
    def test_print_with_input_mutations(self):
        # A user input mutated in place between two prints: the export signature
        # must list one effect token in/out and one mutated user input.
        class M(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                x.add_(x + x)
                out = x + x
                torch.ops.aten._print("moo")
                return (out,)

        example_inputs = (torch.randn(3),)

        # Functionalized export wraps each print in with_effects().
        _, signature = aot_export_module(M(), example_inputs, trace_joint=False)
        self.assertEqual(len(signature.input_tokens), 1)
        self.assertEqual(len(signature.output_tokens), 1)
        self.assertEqual(len(signature.user_inputs_to_mutate), 1)
221

222
    def test_alias_op(self):
        # with_effects must refuse ops that alias their inputs; the in-place
        # absolute_ is such an op, so tracing it should raise.
        def f(token, x):
            new_token, result = with_effects(
                token, torch.ops.aten.absolute_.default, x
            )
            return new_token, result

        token_arg, operand = torch.tensor([]), torch.tensor(4)
        traced_f = make_fx(f)
        with self.assertRaisesRegex(
            AssertionError, r"Ops with aliasing is not supported"
        ):
            traced_f(token_arg, operand)
231

232
    def test_compile_aot_eager(self):
        # A function with prints around a pure computation compiles under the
        # aot_eager backend and matches eager results.
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        example_inputs = (torch.randn(2, 3),)

        compiled = torch.compile(f, backend="aot_eager")
        actual = compiled(*example_inputs)
        expected = f(*example_inputs)
        self.assertTrue(torch.allclose(actual, expected))
243

244
    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(not SM70OrLater, "triton")
    def test_compile_inductor(self):
        # Same as the aot_eager variant, but through the inductor backend.
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        example_inputs = (torch.randn(2, 3),)

        compiled = torch.compile(f, backend="inductor")
        actual = compiled(*example_inputs)
        expected = f(*example_inputs)
        self.assertTrue(torch.allclose(actual, expected))
257

258
    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @skipIfNoDynamoSupport
    def test_compile_inductor_external_op_return_none(self):
        # A mutating custom op that returns nothing ("-> ()") must compile under
        # inductor and give the same result as eager.
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::inplace_add",
                "(Tensor input, Tensor(a!) output) -> ()",
                lib=lib,
            )

            def inplace_add(input: torch.Tensor, output: torch.Tensor) -> None:
                # Eager kernel: accumulate `input` into `output` in place.
                assert input.device == output.device
                output.add_(input)

            lib.impl("inplace_add", inplace_add, "CompositeExplicitAutograd")

            def f(x):
                out = torch.zeros_like(torch.empty(3))
                torch.ops.mylib.inplace_add(x, out)
                return out

            example_inputs = (torch.randn(3),)

            compiled = torch.compile(f, backend="inductor")
            actual = compiled(*example_inputs)
            self.assertTrue(torch.allclose(actual, f(*example_inputs)))
284

285
    def test_compile_aot_eager_requires_grad(self):
        # The aot_eager compile of a print-containing function must also support
        # running backward through the compiled output.
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        example_inputs = (torch.randn(2, 3, requires_grad=True),)

        compiled = torch.compile(f, backend="aot_eager")
        actual = compiled(*example_inputs)
        expected = f(*example_inputs)
        self.assertTrue(torch.allclose(actual, expected))

        actual.sum().backward()
298

299
    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(TEST_WITH_ROCM, "triton")
    @unittest.skipIf(not SM80OrLater, "triton")
    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
    @unittest.skipIf(not TEST_CUDA, "triton")
    @skipIfNoDynamoSupport
    # Patched (not assigned) so the settings are restored when the test ends
    # instead of leaking into every later test in the same process.
    @torch._dynamo.config.patch(
        capture_scalar_outputs=True, capture_dynamic_output_shape_ops=True
    )
    def test_register_effectful_custom_op(self):
        # `mylib::record_scalar_tensor` stores a clone of its tensor argument in
        # `recorded_dict` under `prefix` and is registered as an ORDERED effectful
        # op. Forward hooks on MockModule and its children call it with a
        # reduction of each module's output. The inductor-compiled module must
        # match eager, support backward, and leave one record per hooked module.
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::record_scalar_tensor",
                "(Tensor x, str prefix) -> ()",
                lib=lib,
            )

            # Filled by the op: prefix -> clone of the recorded tensor.
            recorded_dict = {}

            # Eager implementation of the custom op.
            @torch.library.impl(
                "mylib::record_scalar_tensor",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def record_scalar_tensor(x, prefix):
                recorded_dict[prefix] = x.clone()
                return

            # Abstract (meta) implementation: the op produces no outputs.
            @torch.library.impl_abstract(
                "mylib::record_scalar_tensor",
                lib=lib,
            )
            def record_scalar_tensor_meta(x, prefix):
                return

            # Aggregation method per module, keyed by dotted module path.
            my_config = {
                "MockModule": "mean",
                "MockModule.linear": "mean",
                "MockModule.relu": "mean",
            }

            class MyLinear(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.randn(out_features, in_features), requires_grad=True
                    )
                    self.bias = torch.nn.Parameter(
                        torch.randn(out_features), requires_grad=True
                    )

                def forward(self, x):
                    return torch.nn.functional.linear(x, self.weight, self.bias)

            class MockModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = MyLinear(10, 10)
                    self.register_buffer(
                        "buf0", torch.randn(10, 10, requires_grad=True)
                    )

                def forward(self, x):
                    return torch.nn.functional.relu(self.linear(x) + self.buf0)

            def forward_hook(
                module: torch.nn.Module,
                inputs: tuple,
                output: torch.Tensor,
                prefix: str,
                aggregate_method: str,
            ) -> torch.Tensor:
                # Record a scalar summary of `output`; the output itself is
                # returned unchanged.
                if aggregate_method == "mean":
                    torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
                elif aggregate_method == "max":
                    torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
                else:
                    # Any other method falls back to recording the sum.
                    torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
                return output

            def add_hooks(module, config):
                # Register forward_hook on `module` and every descendant, named by
                # dotted path starting from the root's class name.
                handles: List[RemovableHandle] = []
                pending = deque([(module.__class__.__name__, module)])
                while pending:
                    name, m = pending.pop()
                    pending.extend(
                        (f"{name}.{child_name}", child)
                        for child_name, child in m.named_children()
                    )
                    aggregate_method = config.get(name, "mean")
                    handles.append(
                        m.register_forward_hook(
                            partial(
                                forward_hook,
                                prefix=f"{name}:{aggregate_method}",
                                aggregate_method=aggregate_method,
                            )
                        )
                    )
                return handles

            record_op = torch.ops.mylib.record_scalar_tensor.default
            _register_effectful_op(record_op, _EffectType.ORDERED)
            try:
                x = torch.randn(10, 10, device="cuda")
                mod = MockModule().to("cuda")

                add_hooks(mod, my_config)

                opt_mod = torch.compile(backend="inductor")(mod)
                y = opt_mod(x)

                self.assertTrue(torch.allclose(y, mod(x)))
                # Ensure it works well with backward and produces a gradient.
                y.sum().backward()
                self.assertIsInstance(opt_mod.linear.weight.grad, torch.Tensor)

                # Hooks run on MockModule and MockModule.linear (relu is
                # functional, not a child module), so two records are expected.
                self.assertEqual(len(recorded_dict), 2)
                self.assertIn("MockModule.linear:mean", recorded_dict)
                self.assertIn("MockModule:mean", recorded_dict)
            finally:
                # Undo the registration explicitly so it does not outlive this test.
                _deregister_effectful_op(record_op)
428

429
    @skipIfNoDynamoSupport
430
    def test_effectful_custom_op_with_subclasses(self):
431
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
432
            lib.define("zoo(Tensor x) -> Tensor")
433
            lib.define("zoo2(Tensor x) -> Tensor")
434

435
            d = {"fw": 0, "bw": 0}
436

437
            def reset_counter():
438
                d["fw"] = 0
439
                d["bw"] = 0
440

441
            def assert_counter(fw, bw):
442
                self.assertEqual(d["fw"], fw)
443
                self.assertEqual(d["bw"], bw)
444

445
            def foo_impl(a):
446
                d["fw"] = d["fw"] + 1
447
                return 2 * a.clone()
448

449
            def foo_meta(a):
450
                return a.clone()
451

452
            def foo2_impl(x):
453
                d["bw"] = d["bw"] + 1
454
                return x.clone()
455

456
            def foo2_meta(a):
457
                return a.clone()
458

459
            for backend in ["CPU", "CUDA"]:
460
                lib.impl("zoo", foo_impl, backend)
461
                lib.impl("zoo2", foo2_impl, backend)
462
            lib.impl("zoo", foo_meta, "Meta")
463
            lib.impl("zoo2", foo2_meta, "Meta")
464

465
            def foo_bwd(ctx, grad):
466
                torch.ops._mylib.zoo2(grad)
467
                return grad.clone()
468

469
            torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)
470

471
            from torch._higher_order_ops.effects import (
472
                _EffectType,
473
                _register_effectful_op,
474
            )
475

476
            _register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
477
            _register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)
478

479
            def fn(x, y):
480
                return torch.ops._mylib.zoo(x) + y
481

482
            def ins_sc():
483
                return (
484
                    TwoTensor(
485
                        torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0])
486
                    ),
487
                    torch.tensor([4.0, 5.0, 6.0]),
488
                )
489

490
            def ins_dense():
491
                return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])
492

493
            for i, (ins_fn, expected_fw_count) in enumerate(
494
                zip([ins_sc, ins_dense], [2, 1])
495
            ):
496
                reset_counter()
497
                ref_out = fn(*ins_fn())
498
                assert_counter(expected_fw_count, 0)
499

500
                compiled_fn = torch.compile(fn, backend="aot_eager")
501
                out = compiled_fn(*ins_fn())
502
                reset_counter()
503
                out = compiled_fn(*ins_fn())
504
                assert_counter(expected_fw_count, 0)
505

506
                self.assertEqual(ref_out, out)
507

508
            def ins_dense_req_grad():
509
                return (
510
                    torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
511
                    torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
512
                )
513

514
            def ins_sc_req_grad():
515
                return (
516
                    TwoTensor(
517
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
518
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
519
                    ),
520
                    TwoTensor(
521
                        torch.tensor([7.0, 8.0, 9.0], requires_grad=True),
522
                        torch.tensor([10.0, 11.0, 12.0], requires_grad=True),
523
                    ),
524
                )
525

526
            for i, (
527
                ins_fn_req_grad,
528
                (
529
                    expected_fw_count,
530
                    expected_fw_count_after_bw,
531
                    expected_bw_count_after_bw,
532
                ),
533
            ) in enumerate(
534
                zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)])
535
            ):
536
                ref_ins = ins_fn_req_grad()
537
                reset_counter()
538
                ref_out = fn(*ref_ins)
539
                assert_counter(expected_fw_count, 0)
540
                ref_out.sum().backward()
541
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
542

543
                compiled_fn = torch.compile(fn, fullgraph=True)
544

545
                ins = ins_fn_req_grad()
546
                out = compiled_fn(*ins)
547
                reset_counter()
548
                out = compiled_fn(*ins)
549
                assert_counter(expected_fw_count, 0)
550
                self.assertEqual(ref_out, out)
551
                out.sum().backward()
552
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
553
                self.assertEqual(ref_ins[1].grad, ins[1].grad)
554
                self.assertEqual(ref_ins[0].grad, ins[0].grad)
555

556
            fw_graph, bw_graph = get_fw_bw_graph(fn, ins_sc_req_grad())
557
            self.assertExpectedInline(
558
                fw_graph.code.strip(),
559
                """\
560
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
561
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.zoo.default, primals_2);  primals_1 = primals_2 = None
562
    getitem = with_effects[0]
563
    getitem_1 = with_effects[1];  with_effects = None
564
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.zoo.default, primals_3);  getitem = primals_3 = None
565
    getitem_2 = with_effects_1[0]
566
    getitem_3 = with_effects_1[1];  with_effects_1 = None
567
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4);  getitem_1 = primals_4 = None
568
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_5);  getitem_3 = primals_5 = None
569
    return (getitem_2, add, add_1)""",
570
            )
571
            self.assertExpectedInline(
572
                bw_graph.code.strip(),
573
                """\
574
def forward(self, tangents_1, tangents_2, tangents_token):
575
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.zoo2.default, tangents_1);  tangents_token = None
576
    getitem_4 = with_effects_2[0];  with_effects_2 = None
577
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.zoo2.default, tangents_2);  getitem_4 = None
578
    getitem_6 = with_effects_3[0];  with_effects_3 = None
579
    clone = torch.ops.aten.clone.default(tangents_1)
580
    clone_1 = torch.ops.aten.clone.default(tangents_2)
581
    return (clone, clone_1, tangents_1, tangents_2, getitem_6)""",
582
            )
583

584
    def test_effects_and_input_mutation_return(self):
585
        def fn(a, b):
586
            torch.ops.aten._print("effect")
587
            return torch.sin(a, out=b)
588

589
        inp = [torch.randn(3, 3), torch.ones(3, 3)]
590
        ref_out = fn(*inp)
591
        out = torch.compile(fn, fullgraph=True)(*inp)
592
        self.assertEqual(ref_out, out)
593

594
        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
595
        self.assertExpectedInline(
596
            fw_graph.code.strip(),
597
            """\
598
def forward(self, arg0_1, arg1_1, arg2_1):
599
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect');  arg0_1 = None
600
    getitem = with_effects[0];  with_effects = None
601
    sin = torch.ops.aten.sin.default(arg1_1);  arg1_1 = None
602
    return (getitem, sin, sin)""",
603
        )
604

605
    def test_effects_and_input_output_view_simple(self):
606
        def fn(a):
607
            return a.view(-1)
608

609
        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
610
        ref_out = fn(*inp)
611
        out = torch.compile(fn, fullgraph=True)(*inp)
612
        self.assertEqual(ref_out, out)
613

614
        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
615
        ref_out = fn(*inp)
616
        out = torch.compile(fn, fullgraph=True)(*inp)
617
        self.assertEqual(ref_out, out)
618

619
        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
620

621
        self.assertExpectedInline(
622
            fw_graph.code.strip(),
623
            """\
624
def forward(self, arg0_1):
625
    view = torch.ops.aten.view.default(arg0_1, [-1]);  arg0_1 = None
626
    return (view,)""",
627
        )
628

629
    def test_effects_and_aliased_outputs(self):
        # Two outputs where one is a view of the other, with a print in between:
        # the compiled version must preserve the aliasing relationship.
        def fn(a):
            b = a.mul(2)
            torch.ops.aten._print("effect")
            c = b.view(-1)
            return b, c

        f_compiled = aot_function(fn, nop)
        for req_grad in (True, False):
            inp = torch.ones(3, requires_grad=req_grad)
            ref_b, ref_c = fn(inp)
            test_b, test_c = f_compiled(inp)
            self.assertEqual(ref_b, test_b)
            self.assertEqual(ref_c, test_c)

            # Mutate the base output in place; the view output must reflect the
            # change in the compiled results exactly as it does in eager.
            ref_b.mul_(3)
            test_b.mul_(3)
            self.assertEqual(ref_b, test_b)
            self.assertEqual(ref_c, test_c)
649

650
    def test_effects_and_input_mutation_is_output(self):
651
        def fn(a):
652
            a.mul_(2)
653
            torch.ops.aten._print("effect")
654
            return a
655

656
        inp = make_inputs_non_leaves([torch.ones(3, 3, requires_grad=True)])
657
        ref_out = fn(*inp)
658
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
659
        self.assertEqual(ref_out, out)
660

661
        inp = [torch.ones(3, 3, requires_grad=False)]
662
        ref_out = fn(*inp)
663
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
664
        self.assertEqual(ref_out, out)
665

666
        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
667
        self.assertExpectedInline(
668
            fw_graph.code.strip(),
669
            """\
670
def forward(self, arg0_1, arg1_1):
671
    mul = torch.ops.aten.mul.Tensor(arg1_1, 2);  arg1_1 = None
672
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect');  arg0_1 = None
673
    getitem = with_effects[0];  with_effects = None
674
    return (getitem, mul, mul)""",
675
        )
676

677
    @skipIfTorchDynamo()
678
    def test_effectful_op_in_backward(self):
679
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
680
            lib.define("foo(Tensor x) -> Tensor")
681

682
            def foo_impl(a):
683
                return a.clone()
684

685
            def foo_bwd(ctx, grad):
686
                return torch.ops._mylib.foo(grad)
687

688
            for backend in ["CPU", "CUDA", "Meta"]:
689
                lib.impl("foo", foo_impl, backend)
690

691
            torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)
692

693
            from torch._higher_order_ops.effects import (
694
                _deregister_effectful_op,
695
                _EffectType,
696
                _register_effectful_op,
697
            )
698

699
            _register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
700
            try:
701

702
                def fn(x, y):
703
                    return torch.ops._mylib.foo(x) + y
704

705
                def ins_dense_req_grad():
706
                    return (
707
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
708
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
709
                    )
710

711
                def ins_sc_req_grad():
712
                    return (
713
                        TwoTensor(
714
                            torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
715
                            torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
716
                        ),
717
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
718
                    )
719

720
                for i, ins_fn in enumerate([ins_dense_req_grad, ins_sc_req_grad]):
721
                    ref_ins = ins_fn()
722

723
                    ref_out = fn(*ref_ins)
724
                    ref_out.sum().backward()
725

726
                    compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
727
                    ins = ins_fn()
728
                    out = compiled_fn(*ins)
729
                    self.assertEqual(ref_out, out)
730
                    out.sum().backward()
731
                    self.assertEqual(ref_ins[1].grad, ins[1].grad)
732
                    self.assertEqual(ref_ins[0].grad, ins[0].grad)
733

734
                    fw_graph, bw_graph = get_fw_bw_graph(fn, ins)
735
                    if i == 0:
736
                        self.assertExpectedInline(
737
                            fw_graph.code.strip(),
738
                            """\
739
def forward(self, primals_1, primals_2, primals_3):
740
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2);  primals_1 = primals_2 = None
741
    getitem = with_effects[0]
742
    getitem_1 = with_effects[1];  with_effects = None
743
    add = torch.ops.aten.add.Tensor(getitem_1, primals_3);  getitem_1 = primals_3 = None
744
    return (getitem, add)""",
745
                        )
746
                        self.assertExpectedInline(
747
                            bw_graph.code.strip(),
748
                            """\
749
def forward(self, tangents_1, tangents_token):
750
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1);  tangents_token = None
751
    getitem_2 = with_effects_1[0]
752
    getitem_3 = with_effects_1[1];  with_effects_1 = None
753
    return (getitem_3, tangents_1, getitem_2)""",
754
                        )
755
                    elif i == 1:
756
                        self.assertExpectedInline(
757
                            fw_graph.code.strip(),
758
                            """\
759
def forward(self, primals_1, primals_2, primals_3, primals_4):
760
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2);  primals_1 = primals_2 = None
761
    getitem = with_effects[0]
762
    getitem_1 = with_effects[1];  with_effects = None
763
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.foo.default, primals_3);  getitem = primals_3 = None
764
    getitem_2 = with_effects_1[0]
765
    getitem_3 = with_effects_1[1];  with_effects_1 = None
766
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4);  getitem_1 = None
767
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_4);  getitem_3 = primals_4 = None
768
    return (getitem_2, add, add_1)""",
769
                        )
770
                        self.assertExpectedInline(
771
                            bw_graph.code.strip(),
772
                            """\
773
def forward(self, tangents_1, tangents_2, tangents_token):
774
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1);  tangents_token = None
775
    getitem_4 = with_effects_2[0]
776
    getitem_5 = with_effects_2[1];  with_effects_2 = None
777
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.foo.default, tangents_2);  getitem_4 = None
778
    getitem_6 = with_effects_3[0]
779
    getitem_7 = with_effects_3[1];  with_effects_3 = None
780
    return (getitem_5, getitem_7, tangents_1, tangents_2, getitem_6)""",
781
                        )
782
                    else:
783
                        raise NotImplementedError
784
            finally:
785
                _deregister_effectful_op(torch.ops._mylib.foo.default)
786

787
    @skipIfNoDynamoSupport
788
    def test_regular_effectful_op_only_in_backward(self):
789
        from torch._higher_order_ops.effects import (
790
            _deregister_effectful_op,
791
            _EffectType,
792
            _register_effectful_op,
793
        )
794

795
        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
796
        try:
797

798
            def fn(x):
799
                return x.sin()
800

801
            def inps_fn():
802
                return (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
803

804
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn())
805

806
            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn())
807
            self.assertExpectedInline(
808
                fw_graph.code.strip(),
809
                """\
810
def forward(self, primals_1):
811
    sin = torch.ops.aten.sin.default(primals_1)
812
    return (sin, primals_1)""",
813
            )
814
            self.assertExpectedInline(
815
                bw_graph.code.strip(),
816
                """\
817
def forward(self, primals_1, tangents_1, tangents_token):
818
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1);  tangents_token = primals_1 = None
819
    getitem = with_effects[0]
820
    getitem_1 = with_effects[1];  with_effects = None
821
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1);  tangents_1 = getitem_1 = None
822
    return (mul, getitem)""",
823
            )
824

825
            def inps_fn_sc():
826
                return (
827
                    TwoTensor(
828
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
829
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
830
                    ),
831
                )
832

833
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn_sc())
834
            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn_sc())
835
            self.assertExpectedInline(
836
                fw_graph.code.strip(),
837
                """\
838
def forward(self, primals_1, primals_2):
839
    sin = torch.ops.aten.sin.default(primals_1)
840
    sin_1 = torch.ops.aten.sin.default(primals_2)
841
    return (sin, sin_1, primals_1, primals_2)""",
842
            )
843
            self.assertExpectedInline(
844
                bw_graph.code.strip(),
845
                """\
846
def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
847
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1);  tangents_token = primals_1 = None
848
    getitem = with_effects[0]
849
    getitem_1 = with_effects[1];  with_effects = None
850
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten.cos.default, primals_2);  getitem = primals_2 = None
851
    getitem_2 = with_effects_1[0]
852
    getitem_3 = with_effects_1[1];  with_effects_1 = None
853
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1);  tangents_1 = getitem_1 = None
854
    mul_1 = torch.ops.aten.mul.Tensor(tangents_2, getitem_3);  tangents_2 = getitem_3 = None
855
    return (mul, mul_1, getitem_2)""",
856
            )
857
        finally:
858
            _deregister_effectful_op(torch.ops.aten.cos.default)
859

860
    @skipIfNoDynamoSupport
861
    def test_regular_effectful_op_in_forward_and_backward(self):
862
        from torch._higher_order_ops.effects import (
863
            _deregister_effectful_op,
864
            _EffectType,
865
            _register_effectful_op,
866
        )
867

868
        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
869
        try:
870

871
            def fn(x):
872
                x = x.cos()
873
                return x.sin()
874

875
            inps = (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
876
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps)
877

878
            fw_graph, bw_graph = get_fw_bw_graph(fn, inps)
879
            self.assertExpectedInline(
880
                fw_graph.code.strip(),
881
                """\
882
def forward(self, primals_1, primals_2):
883
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.aten.cos.default, primals_2);  primals_1 = None
884
    getitem = with_effects[0]
885
    getitem_1 = with_effects[1];  with_effects = None
886
    sin = torch.ops.aten.sin.default(getitem_1)
887
    return (getitem, sin, primals_2, getitem_1)""",
888
            )
889
            self.assertExpectedInline(
890
                bw_graph.code.strip(),
891
                """\
892
def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
893
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, getitem_1);  tangents_token = getitem_1 = None
894
    getitem_2 = with_effects_1[0]
895
    getitem_3 = with_effects_1[1];  with_effects_1 = None
896
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_3);  tangents_1 = getitem_3 = None
897
    sin_1 = torch.ops.aten.sin.default(primals_2);  primals_2 = None
898
    neg = torch.ops.aten.neg.default(sin_1);  sin_1 = None
899
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg);  mul = neg = None
900
    return (mul_1, getitem_2)""",
901
            )
902
        finally:
903
            _deregister_effectful_op(torch.ops.aten.cos.default)
904

905

906
if __name__ == "__main__":
    # Entry point when the file is executed directly: run every test case
    # defined in this module via PyTorch's test runner.
    run_tests()

Использование cookies

Мы используем файлы cookie в соответствии с Политикой конфиденциальности и Политикой использования cookies.

Нажимая кнопку «Принимаю», Вы даете АО «СберТех» согласие на обработку Ваших персональных данных в целях совершенствования нашего веб-сайта и Сервиса GitVerse, а также повышения удобства их использования.

Запретить использование cookies Вы можете самостоятельно в настройках Вашего браузера.