# Owner(s): ["module: functorch"]
# flake8: noqa: B950
import unittest
from collections import deque
from functools import partial
from typing import List, TYPE_CHECKING

import torch
import torch._dynamo
import torch._functorch
import torch._inductor
import torch._inductor.decomposition
from functorch.compile import (
    aot_function,
    default_decompositions,
    min_cut_rematerialization_partition,
    nop,
)
from torch._functorch.aot_autograd import aot_export_module
from torch._higher_order_ops.effects import with_effects
from torch._higher_order_ops.torchbind import enable_torchbind_tracing
from torch.fx.experimental.proxy_tensor import make_fx
from torch.testing import FileCheck
from torch.testing._internal.common_cuda import (
    _get_torch_cuda_version,
    SM70OrLater,
    SM80OrLater,
)
from torch.testing._internal.common_quantization import skipIfNoDynamoSupport
from torch.testing._internal.common_utils import (
    IS_WINDOWS,
    run_tests,
    skipIfTorchDynamo,
    TEST_CUDA,
    TEST_WITH_ROCM,
    TestCase,
)
from torch.testing._internal.torchbind_impls import init_torchbind_implementations


if TYPE_CHECKING:
    from torch.utils.hooks import RemovableHandle

from torch.testing._internal.two_tensor import TwoTensor


def extract_graph(fx_g, _, graph_cell):
    graph_cell[0] = fx_g
    return fx_g


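# Trace `f` through aot_function with graph-capturing "compilers" so tests can inspect
# the partitioned forward/backward FX graphs; the backward graph is only captured (and
# backward only run) when at least one input requires grad.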
def get_fw_bw_graph(
    f, inps, partitioner=min_cut_rematerialization_partition, dynamic=False
):
    fw_graph_cell = [None]
    bw_graph_cell = [None]
    requires_grad = False

    def fn_req_grad(t):
        nonlocal requires_grad
        requires_grad = requires_grad or t.requires_grad
        return t

    torch.utils._pytree.tree_map_only(torch.Tensor, fn_req_grad, inps)

    out = aot_function(
        f,
        fw_compiler=partial(extract_graph, graph_cell=fw_graph_cell),
        bw_compiler=partial(extract_graph, graph_cell=bw_graph_cell)
        if requires_grad
        else nop,
        partition_fn=partitioner,
        decompositions=default_decompositions,
        dynamic=dynamic,
    )(*inps)

    if requires_grad:
        out.sum().backward()

    return (fw_graph_cell[0], bw_graph_cell[0])


def make_inputs_non_leaves(inps):
    return torch.utils._pytree.tree_map_only(torch.Tensor, lambda t: t.add(1), inps)


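# These tests exercise the with_effects higher-order op: under functionalization,
# side-effectful ops (prints, effectful custom ops, torchbind calls) are rewritten as
# with_effects(token, op, *args), threading an opaque token through the graph so the
# original side-effect ordering is preserved through compilation.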
@unittest.skipIf(not torch._dynamo.is_dynamo_supported(), "dynamo isn't supported")
class TestWithEffects(TestCase):
    def setUp(self):
        init_torchbind_implementations()

    def test_print(self):
        class M(torch.nn.Module):
            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # Without functionalization, print should just appear in the graph directly
        gm = make_fx(M())(*inputs)
        FileCheck().check_count("torch.ops.aten._print.default", 2, exactly=True).run(
            gm.code
        )

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

        with torch._functorch.config.patch(unlift_effect_tokens=True):
            gm, gs = aot_export_module(M(), inputs, trace_joint=False)
            self.assertExpectedInline(
                str(gm.code).strip(),
                """\
def forward(self, arg1_1):
    _make_token_default = torch.ops.prims._make_token.default()
    with_effects = torch.ops.higher_order.with_effects(_make_token_default, torch.ops.aten._print.default, 'moo'); _make_token_default = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, arg1_1); arg1_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    _sink_tokens_default = torch.ops.prims._sink_tokens.default([getitem_2]); getitem_2 = _sink_tokens_default = None
    return [add]""",  # noqa: B950
            )

    def test_torchbind_custom_op(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.attr = torch.classes._TorchScriptTesting._Foo(10, 20)

            def forward(self, x):
                return (x + torch.ops._TorchScriptTesting.takes_foo(self.attr, x),)

        with enable_torchbind_tracing():
            gm, gs = aot_export_module(M(), (torch.ones(2, 3),), trace_joint=False)

        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1):
    _torchbind_obj0 = self._torchbind_obj0
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops._TorchScriptTesting.takes_foo.default, _torchbind_obj0, arg1_1); arg0_1 = _torchbind_obj0 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg1_1, getitem_1); arg1_1 = getitem_1 = None
    return (getitem, add)""",  # noqa: B950
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)

    def test_print_with_buffer_mutations(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()
                self.buf = torch.nn.Buffer(torch.ones(3))

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                self.buf.add_(res)
                res = self.buf + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertExpectedInline(
            str(gm.code).strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'moo'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    add = torch.ops.aten.add.Tensor(arg2_1, arg2_1)
    add_1 = torch.ops.aten.add.Tensor(arg1_1, add); arg1_1 = add = None
    add_2 = torch.ops.aten.add.Tensor(add_1, arg2_1); arg2_1 = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten._print.default, 'moo'); getitem = None
    getitem_2 = with_effects_1[0]; with_effects_1 = None
    return (getitem_2, add_1, add_2)""",
        )
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.buffers_to_mutate), 1)

    def test_print_with_input_mutations(self):
        class M(torch.nn.Module):
            def __init__(self) -> None:
                super().__init__()

            def forward(self, x):
                torch.ops.aten._print("moo")
                res = x + x
                x.add_(res)
                res = x + x
                torch.ops.aten._print("moo")
                return (res,)

        inputs = (torch.randn(3),)

        # With functionalization, it should appear wrapped with with_effects()
        gm, gs = aot_export_module(M(), inputs, trace_joint=False)
        self.assertEqual(len(gs.input_tokens), 1)
        self.assertEqual(len(gs.output_tokens), 1)
        self.assertEqual(len(gs.user_inputs_to_mutate), 1)

    def test_alias_op(self):
        def f(token, x):
            token, out = with_effects(token, torch.ops.aten.absolute_.default, x)
            return token, out

        with self.assertRaisesRegex(
            AssertionError, r"Ops with aliasing is not supported"
        ):
            make_fx(f)(torch.tensor([]), torch.tensor(4))

    def test_compile_aot_eager(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(not SM70OrLater, "triton")
    def test_compile_inductor(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3),)

        res = torch.compile(f, backend="inductor")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

    @unittest.skipIf(IS_WINDOWS, "Skipped on Windows!")
    @skipIfNoDynamoSupport
    def test_compile_inductor_external_op_return_none(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch.library.define(
                "mylib::inplace_add",
                "(Tensor input, Tensor(a!) output) -> ()",
                lib=lib,
            )

            def inplace_add(input: torch.Tensor, output: torch.Tensor) -> None:
                assert input.device == output.device
                output.add_(input)

            lib.impl("inplace_add", inplace_add, "CompositeExplicitAutograd")

            def f(x):
                out = torch.empty(3)
                out = torch.zeros_like(out)
                torch.ops.mylib.inplace_add(x, out)
                return out

            inputs = (torch.randn(3),)

            res = torch.compile(f, backend="inductor")(*inputs)
            self.assertTrue(torch.allclose(res, f(*inputs)))

    def test_compile_aot_eager_requires_grad(self):
        def f(x):
            torch.ops.aten._print("moo")
            res = x + x
            torch.ops.aten._print("moo")
            return res

        inputs = (torch.randn(2, 3, requires_grad=True),)

        res = torch.compile(f, backend="aot_eager")(*inputs)
        self.assertTrue(torch.allclose(res, f(*inputs)))

        res.sum().backward()

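    # Registering mylib::record_scalar_tensor as an ORDERED effect threads a token
    # through the forward-hook recording calls so they are kept and ordered in the
    # compiled graph even though they return nothing.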
    @unittest.skipIf(IS_WINDOWS, "triton")
    @unittest.skipIf(TEST_WITH_ROCM, "triton")
    @unittest.skipIf(not SM80OrLater, "triton")
    @unittest.skipIf(_get_torch_cuda_version() >= (11, 7), "triton")
    @unittest.skipIf(not TEST_CUDA, "triton")
    @skipIfNoDynamoSupport
    def test_register_effectful_custom_op(self):
        with torch.library._scoped_library("mylib", "FRAGMENT") as lib:
            torch._dynamo.config.capture_scalar_outputs = True
            torch._dynamo.config.capture_dynamic_output_shape_ops = True

            torch.library.define(
                "mylib::record_scalar_tensor",
                "(Tensor x, str prefix) -> ()",
                lib=lib,
            )

            # global variable to store the recorded tensor and prefix.
            recorded_dict = {}

            # PyTorch custom op implementation
            @torch.library.impl(
                "mylib::record_scalar_tensor",
                "CompositeExplicitAutograd",
                lib=lib,
            )
            def record_scalar_tensor(x, prefix):
                recorded_dict[prefix] = x.clone()
                return

            # Meta function of the custom op
            @torch.library.impl_abstract(
                "mylib::record_scalar_tensor",
                lib=lib,
            )
            def record_scalar_tensor_meta(x, prefix):
                return

            from torch._higher_order_ops.effects import (
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(
                torch.ops.mylib.record_scalar_tensor.default, _EffectType.ORDERED
            )

            my_config = {}
            my_config["MockModule"] = "mean"
            my_config["MockModule.linear"] = "mean"
            my_config["MockModule.relu"] = "mean"

            class MyLinear(torch.nn.Module):
                def __init__(self, in_features, out_features):
                    super().__init__()
                    self.weight = torch.nn.Parameter(
                        torch.randn(out_features, in_features), requires_grad=True
                    )
                    self.bias = torch.nn.Parameter(
                        torch.randn(out_features), requires_grad=True
                    )

                def forward(self, x):
                    return torch.nn.functional.linear(x, self.weight, self.bias)

            class MockModule(torch.nn.Module):
                def __init__(self) -> None:
                    super().__init__()
                    self.linear = MyLinear(10, 10)
                    self.register_buffer(
                        "buf0", torch.randn(10, 10, requires_grad=True)
                    )

                def forward(self, x):
                    return torch.nn.functional.relu(self.linear(x) + self.buf0)

            def forward_hook(
                module: torch.nn.Module,
                inputs: torch.Tensor,
                output: torch.Tensor,
                prefix: str,
                aggregate_method: str,
            ) -> torch.Tensor:
                if aggregate_method == "mean":
                    torch.ops.mylib.record_scalar_tensor(output.mean(), prefix)
                elif aggregate_method == "max":
                    torch.ops.mylib.record_scalar_tensor(output.max(), prefix)
                else:
                    # demo fallback: record the sum
                    torch.ops.mylib.record_scalar_tensor(output.sum(), prefix)
                return output

            def add_hooks(module, config):
                handles: List[RemovableHandle] = []
                q = deque([(module.__class__.__name__, module)])
                while q:
                    name, m = q.pop()
                    children = [(name + "." + n, y) for (n, y) in m.named_children()]
                    q.extend(children)
                    aggregate_method = config.get(name, "mean")
                    prefix = name + ":" + aggregate_method
                    handle = m.register_forward_hook(
                        partial(
                            forward_hook,
                            prefix=prefix,
                            aggregate_method=aggregate_method,
                        )
                    )
                    if handle:
                        handles.append(handle)
                return handles

            x = torch.randn(10, 10, device="cuda")
            mod = MockModule().to("cuda")

            add_hooks(mod, my_config)

            opt_mod = torch.compile(backend="inductor")(mod)
            y = opt_mod(x)

            self.assertTrue(torch.allclose(y, mod(x)))
            # Ensure backward runs through the compiled module
            y.sum().backward()
            # Ensure the gradient was actually populated
            self.assertTrue(isinstance(opt_mod.linear.weight.grad, torch.Tensor))

            self.assertEqual(len(recorded_dict), 2)
            self.assertTrue("MockModule.linear:mean" in recorded_dict)
            self.assertTrue("MockModule:mean" in recorded_dict)

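    # With TwoTensor subclass inputs, each call to the effectful op runs once per inner
    # dense tensor, so the expected fw/bw counters are doubled relative to dense inputs.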
    @skipIfNoDynamoSupport
    def test_effectful_custom_op_with_subclasses(self):
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("zoo(Tensor x) -> Tensor")
            lib.define("zoo2(Tensor x) -> Tensor")

            d = {"fw": 0, "bw": 0}

            def reset_counter():
                d["fw"] = 0
                d["bw"] = 0

            def assert_counter(fw, bw):
                self.assertEqual(d["fw"], fw)
                self.assertEqual(d["bw"], bw)

            def foo_impl(a):
                d["fw"] = d["fw"] + 1
                return 2 * a.clone()

            def foo_meta(a):
                return a.clone()

            def foo2_impl(x):
                d["bw"] = d["bw"] + 1
                return x.clone()

            def foo2_meta(a):
                return a.clone()

            for backend in ["CPU", "CUDA"]:
                lib.impl("zoo", foo_impl, backend)
                lib.impl("zoo2", foo2_impl, backend)
            lib.impl("zoo", foo_meta, "Meta")
            lib.impl("zoo2", foo2_meta, "Meta")

            def foo_bwd(ctx, grad):
                torch.ops._mylib.zoo2(grad)
                return grad.clone()

            torch.library.register_autograd("_mylib::zoo", foo_bwd, lib=lib)

            from torch._higher_order_ops.effects import (
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(torch.ops._mylib.zoo.default, _EffectType.ORDERED)
            _register_effectful_op(torch.ops._mylib.zoo2.default, _EffectType.ORDERED)

            def fn(x, y):
                return torch.ops._mylib.zoo(x) + y

            def ins_sc():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0]), torch.tensor([1.0, 2.0, 3.0])
                    ),
                    torch.tensor([4.0, 5.0, 6.0]),
                )

            def ins_dense():
                return torch.tensor([1.0, 2.0, 3.0]), torch.tensor([4.0, 5.0, 6.0])

            for i, (ins_fn, expected_fw_count) in enumerate(
                zip([ins_sc, ins_dense], [2, 1])
            ):
                reset_counter()
                ref_out = fn(*ins_fn())
                assert_counter(expected_fw_count, 0)

                compiled_fn = torch.compile(fn, backend="aot_eager")
                out = compiled_fn(*ins_fn())
                reset_counter()
                out = compiled_fn(*ins_fn())
                assert_counter(expected_fw_count, 0)

                self.assertEqual(ref_out, out)

            def ins_dense_req_grad():
                return (
                    torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                    torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                )

            def ins_sc_req_grad():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    ),
                    TwoTensor(
                        torch.tensor([7.0, 8.0, 9.0], requires_grad=True),
                        torch.tensor([10.0, 11.0, 12.0], requires_grad=True),
                    ),
                )

            for i, (
                ins_fn_req_grad,
                (
                    expected_fw_count,
                    expected_fw_count_after_bw,
                    expected_bw_count_after_bw,
                ),
            ) in enumerate(
                zip([ins_dense_req_grad, ins_sc_req_grad], [(1, 1, 1), (2, 2, 2)])
            ):
                ref_ins = ins_fn_req_grad()
                reset_counter()
                ref_out = fn(*ref_ins)
                assert_counter(expected_fw_count, 0)
                ref_out.sum().backward()
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)

                compiled_fn = torch.compile(fn, fullgraph=True)

                ins = ins_fn_req_grad()
                out = compiled_fn(*ins)
                reset_counter()
                out = compiled_fn(*ins)
                assert_counter(expected_fw_count, 0)
                self.assertEqual(ref_out, out)
                out.sum().backward()
                assert_counter(expected_fw_count_after_bw, expected_bw_count_after_bw)
                self.assertEqual(ref_ins[1].grad, ins[1].grad)
                self.assertEqual(ref_ins[0].grad, ins[0].grad)

            fw_graph, bw_graph = get_fw_bw_graph(fn, ins_sc_req_grad())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2, primals_3, primals_4, primals_5):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.zoo.default, primals_2); primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.zoo.default, primals_3); getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4); getitem_1 = primals_4 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_5); getitem_3 = primals_5 = None
    return (getitem_2, add, add_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.zoo2.default, tangents_1); tangents_token = None
    getitem_4 = with_effects_2[0]; with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.zoo2.default, tangents_2); getitem_4 = None
    getitem_6 = with_effects_3[0]; with_effects_3 = None
    clone = torch.ops.aten.clone.default(tangents_1)
    clone_1 = torch.ops.aten.clone.default(tangents_2)
    return (clone, clone_1, tangents_1, tangents_2, getitem_6)""",
            )

    def test_effects_and_input_mutation_return(self):
        def fn(a, b):
            torch.ops.aten._print("effect")
            return torch.sin(a, out=b)

        inp = [torch.randn(3, 3), torch.ones(3, 3)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1, arg2_1):
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    sin = torch.ops.aten.sin.default(arg1_1); arg1_1 = None
    return (getitem, sin, sin)""",
        )

    def test_effects_and_input_output_view_simple(self):
        def fn(a):
            return a.view(-1)

        inp = [torch.ones(2, 2, requires_grad=False).add(1)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        inp = [torch.ones(2, 2, requires_grad=True).add(1)]
        ref_out = fn(*inp)
        out = torch.compile(fn, fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)

        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1):
    view = torch.ops.aten.view.default(arg0_1, [-1]); arg0_1 = None
    return (view,)""",
        )

    def test_effects_and_aliased_outputs(self):
        def fn(a):
            b = a.mul(2)
            torch.ops.aten._print("effect")
            c = b.view(-1)
            return b, c

        f_compiled = aot_function(fn, nop)
        for req_grad in [True, False]:
            inp = torch.ones(3, requires_grad=req_grad)
            out_ref = fn(inp)
            out_test = f_compiled(inp)
            self.assertEqual(out_ref[0], out_test[0])
            self.assertEqual(out_ref[1], out_test[1])
            # Try mutating one of the outputs, which is aliased.
            out_ref[0].mul_(3)
            out_test[0].mul_(3)
            # Assert that the aliasing relationship was preserved
            self.assertEqual(out_ref[0], out_test[0])
            self.assertEqual(out_ref[1], out_test[1])

    def test_effects_and_input_mutation_is_output(self):
        def fn(a):
            a.mul_(2)
            torch.ops.aten._print("effect")
            return a

        inp = make_inputs_non_leaves([torch.ones(3, 3, requires_grad=True)])
        ref_out = fn(*inp)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        inp = [torch.ones(3, 3, requires_grad=False)]
        ref_out = fn(*inp)
        out = torch.compile(fn, backend="aot_eager", fullgraph=True)(*inp)
        self.assertEqual(ref_out, out)

        fw_graph, bw_graph = get_fw_bw_graph(fn, inp)
        self.assertExpectedInline(
            fw_graph.code.strip(),
            """\
def forward(self, arg0_1, arg1_1):
    mul = torch.ops.aten.mul.Tensor(arg1_1, 2); arg1_1 = None
    with_effects = torch.ops.higher_order.with_effects(arg0_1, torch.ops.aten._print.default, 'effect'); arg0_1 = None
    getitem = with_effects[0]; with_effects = None
    return (getitem, mul, mul)""",
        )

    @skipIfTorchDynamo()
    def test_effectful_op_in_backward(self):
        with torch.library._scoped_library("_mylib", "FRAGMENT") as lib:
            lib.define("foo(Tensor x) -> Tensor")

            def foo_impl(a):
                return a.clone()

            def foo_bwd(ctx, grad):
                return torch.ops._mylib.foo(grad)

            for backend in ["CPU", "CUDA", "Meta"]:
                lib.impl("foo", foo_impl, backend)

            torch.library.register_autograd("_mylib::foo", foo_bwd, lib=lib)

            from torch._higher_order_ops.effects import (
                _deregister_effectful_op,
                _EffectType,
                _register_effectful_op,
            )

            _register_effectful_op(torch.ops._mylib.foo.default, _EffectType.ORDERED)
            try:

                def fn(x, y):
                    return torch.ops._mylib.foo(x) + y

                def ins_dense_req_grad():
                    return (
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                def ins_sc_req_grad():
                    return (
                        TwoTensor(
                            torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                            torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                        ),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    )

                for i, ins_fn in enumerate([ins_dense_req_grad, ins_sc_req_grad]):
                    ref_ins = ins_fn()

                    ref_out = fn(*ref_ins)
                    ref_out.sum().backward()

                    compiled_fn = torch.compile(fn, backend="inductor", fullgraph=True)
                    ins = ins_fn()
                    out = compiled_fn(*ins)
                    self.assertEqual(ref_out, out)
                    out.sum().backward()
                    self.assertEqual(ref_ins[1].grad, ins[1].grad)
                    self.assertEqual(ref_ins[0].grad, ins[0].grad)

                    fw_graph, bw_graph = get_fw_bw_graph(fn, ins)
                    if i == 0:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_3); getitem_1 = primals_3 = None
    return (getitem, add)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    return (getitem_3, tangents_1, getitem_2)""",
                        )
                    elif i == 1:
                        self.assertExpectedInline(
                            fw_graph.code.strip(),
                            """\
def forward(self, primals_1, primals_2, primals_3, primals_4):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops._mylib.foo.default, primals_2); primals_1 = primals_2 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops._mylib.foo.default, primals_3); getitem = primals_3 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    add = torch.ops.aten.add.Tensor(getitem_1, primals_4); getitem_1 = None
    add_1 = torch.ops.aten.add.Tensor(getitem_3, primals_4); getitem_3 = primals_4 = None
    return (getitem_2, add, add_1)""",
                        )
                        self.assertExpectedInline(
                            bw_graph.code.strip(),
                            """\
def forward(self, tangents_1, tangents_2, tangents_token):
    with_effects_2 = torch.ops.higher_order.with_effects(tangents_token, torch.ops._mylib.foo.default, tangents_1); tangents_token = None
    getitem_4 = with_effects_2[0]
    getitem_5 = with_effects_2[1]; with_effects_2 = None
    with_effects_3 = torch.ops.higher_order.with_effects(getitem_4, torch.ops._mylib.foo.default, tangents_2); getitem_4 = None
    getitem_6 = with_effects_3[0]
    getitem_7 = with_effects_3[1]; with_effects_3 = None
    return (getitem_5, getitem_7, tangents_1, tangents_2, getitem_6)""",
                        )
                    else:
                        raise NotImplementedError
            finally:
                _deregister_effectful_op(torch.ops._mylib.foo.default)

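    # aten.cos only appears in sin's backward formula here, so once it is registered as
    # an ORDERED effect the forward graph stays token-free while the backward graph
    # wraps cos in with_effects.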
    @skipIfNoDynamoSupport
    def test_regular_effectful_op_only_in_backward(self):
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                return x.sin()

            def inps_fn():
                return (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)

            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn())

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1):
    sin = torch.ops.aten.sin.default(primals_1)
    return (sin, primals_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, tangents_1, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
    return (mul, getitem)""",
            )

            def inps_fn_sc():
                return (
                    TwoTensor(
                        torch.tensor([1.0, 2.0, 3.0], requires_grad=True),
                        torch.tensor([4.0, 5.0, 6.0], requires_grad=True),
                    ),
                )

            torch.compile(fn, backend="inductor", fullgraph=True)(*inps_fn_sc())
            fw_graph, bw_graph = get_fw_bw_graph(fn, inps_fn_sc())
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    sin = torch.ops.aten.sin.default(primals_1)
    sin_1 = torch.ops.aten.sin.default(primals_2)
    return (sin, sin_1, primals_1, primals_2)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2, tangents_1, tangents_2, tangents_token):
    with_effects = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, primals_1); tangents_token = primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    with_effects_1 = torch.ops.higher_order.with_effects(getitem, torch.ops.aten.cos.default, primals_2); getitem = primals_2 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_1); tangents_1 = getitem_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(tangents_2, getitem_3); tangents_2 = getitem_3 = None
    return (mul, mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)

    @skipIfNoDynamoSupport
    def test_regular_effectful_op_in_forward_and_backward(self):
        from torch._higher_order_ops.effects import (
            _deregister_effectful_op,
            _EffectType,
            _register_effectful_op,
        )

        _register_effectful_op(torch.ops.aten.cos.default, _EffectType.ORDERED)
        try:

            def fn(x):
                x = x.cos()
                return x.sin()

            inps = (torch.tensor([1.0, 2.0, 3.0], requires_grad=True),)
            torch.compile(fn, backend="inductor", fullgraph=True)(*inps)

            fw_graph, bw_graph = get_fw_bw_graph(fn, inps)
            self.assertExpectedInline(
                fw_graph.code.strip(),
                """\
def forward(self, primals_1, primals_2):
    with_effects = torch.ops.higher_order.with_effects(primals_1, torch.ops.aten.cos.default, primals_2); primals_1 = None
    getitem = with_effects[0]
    getitem_1 = with_effects[1]; with_effects = None
    sin = torch.ops.aten.sin.default(getitem_1)
    return (getitem, sin, primals_2, getitem_1)""",
            )
            self.assertExpectedInline(
                bw_graph.code.strip(),
                """\
def forward(self, primals_2, getitem_1, tangents_1, tangents_token):
    with_effects_1 = torch.ops.higher_order.with_effects(tangents_token, torch.ops.aten.cos.default, getitem_1); tangents_token = getitem_1 = None
    getitem_2 = with_effects_1[0]
    getitem_3 = with_effects_1[1]; with_effects_1 = None
    mul = torch.ops.aten.mul.Tensor(tangents_1, getitem_3); tangents_1 = getitem_3 = None
    sin_1 = torch.ops.aten.sin.default(primals_2); primals_2 = None
    neg = torch.ops.aten.neg.default(sin_1); sin_1 = None
    mul_1 = torch.ops.aten.mul.Tensor(mul, neg); mul = neg = None
    return (mul_1, getitem_2)""",
            )
        finally:
            _deregister_effectful_op(torch.ops.aten.cos.default)

if __name__ == "__main__":
    run_tests()