# Owner(s): ["oncall: pt2"]
import functools
import sys
import unittest
from unittest.mock import patch

import torch
import torch.utils.checkpoint
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, run_tests, TestCase

if IS_WINDOWS and IS_CI:
    sys.stderr.write("torch.compile not supported on windows")
    if __name__ == "__main__":
        sys.exit(0)
    raise unittest.SkipTest("torch.compile not supported on windows")


def count_philox_rand(gm, args, freq):
    # Assert the traced graph contains exactly `freq` philox_rand calls, then
    # return the graph unchanged so aot_function can run it.
    assert [node.target for node in gm.graph.nodes].count(
        torch.ops.rngprims.philox_rand.default
    ) == freq
    return gm
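

# How these tests work: with torch._functorch.config.functionalize_rng_ops
# enabled, AOTAutograd is expected to replace eager RNG ops (rand, rand_like,
# and the rand inside dropout) with torch.ops.rngprims.philox_rand, which
# takes the Philox seed and offset as explicit inputs. Randomness then becomes
# a pure function of those inputs, so the traced fwd/bwd graphs can replay
# exactly the random numbers that eager mode produced. count_philox_rand is
# passed as the fwd/bwd "compiler" to assert how many such calls survive
# tracing and partitioning.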


class TestFunctionalizationRngOps(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            # Both rand_like calls should functionalize into philox_rand.
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            # Vary the shape each iteration so dynamic shapes are exercised.
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype)
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand_like_dynamic_bwd(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        for seed in range(1, 10):
            shape = (seed, seed)
            x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
            res = opt_fn(x)
            res.sum().backward()

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_rand(self, dtype, device):
        shape = (10,)

        def fn(x):
            a = torch.rand(*shape, device=device, dtype=dtype) * x
            a = torch.rand(*shape, device=device, dtype=dtype) * a
            return a

        x = torch.rand(*shape, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_autograd_function(self, dtype, device):
        shape = (16, 16)

        class Custom(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        custom = Custom.apply

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = custom(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        # Two rand_like calls in fwd, one in bwd.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_multiple_subgraphs(self, dtype, device):
        # Checks that rng state is maintained when there are multiple
        # aot traced graphs.
        shape = (16, 16)

        class CustomOp1(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                a = torch.rand_like(x) * a
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.cos(x)

        class CustomOp2(torch.autograd.Function):
            @staticmethod
            def forward(ctx, x):
                ctx.save_for_backward(x)
                a = torch.rand_like(x) * x
                return a

            @staticmethod
            def backward(ctx, grad_out):
                (x,) = ctx.saved_tensors
                return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)

        custom_op1 = CustomOp1.apply
        custom_op2 = CustomOp2.apply

        def fn(x):
            a = custom_op1(x)
            b = a.sin()
            return custom_op2(b)

        # CustomOp1: two rand_like calls in fwd, one in bwd.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
        # CustomOp2: one rand_like call in fwd, two in bwd.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=2)
        aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)

        def aot_fn(x):
            a = aot_custom_op1(x)
            b = a.sin()
            return aot_custom_op2(b)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
            x_clone = x.clone().detach().requires_grad_(True)

            torch.cuda.manual_seed(seed)
            ref = fn(x)
            ref.sum().backward()

            torch.cuda.manual_seed(seed)
            res = aot_fn(x_clone)
            res.sum().backward()

            self.assertEqual(ref, res)
            self.assertEqual(x.grad, x_clone.grad)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_set_get_rng_state(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            state = torch.cuda.get_rng_state()
            a = torch.rand_like(x) * a
            # Restoring the state makes the third rand_like replay the second.
            torch.cuda.set_rng_state(state)
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        for seed in range(10):
            torch.cuda.manual_seed(seed)
            ref = fn(x)

            torch.cuda.manual_seed(seed)
            fwd_compiler = functools.partial(count_philox_rand, freq=3)
            aot_fn = aot_function(fn, fwd_compiler)
            res = aot_fn(x)

            self.assertEqual(ref, res)

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_min_cut_partitioner(self, dtype, device):
        # Checks that the calling convention is maintained
        shape = (16, 16)

        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            a = torch.sin(a)
            a = torch.sin(a)
            a = torch.sin(a)
            return a

        x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
        x_clone = x.clone().detach().requires_grad_(True)

        torch.cuda.manual_seed(123)
        ref = fn(x)
        ref.sum().backward()

        torch.cuda.manual_seed(123)
        # The partitioner saves the rand outputs for bwd, hence freq=0 in bwd.
        fwd_compiler = functools.partial(count_philox_rand, freq=2)
        bwd_compiler = functools.partial(count_philox_rand, freq=0)
        aot_custom = aot_function(
            fn,
            fwd_compiler,
            bwd_compiler,
            partition_fn=min_cut_rematerialization_partition,
        )
        # aot_custom = aot_function(fn, fwd_compiler, bwd_compiler)
        res = aot_custom(x_clone)
        res.sum().backward()

        self.assertEqual(ref, res)
        self.assertEqual(x.grad, x_clone.grad)

    # TODO - Dropout needs more work because of offset calculation
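    # (Each philox_rand call advances the Philox offset by an amount that
    # depends on how many random numbers it draws, so replaying dropout in
    # bwd requires reproducing the exact per-call offsets from fwd.)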
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    @dtypes(torch.float32)
    def test_checkpoint(self, dtype, device):
        def g(x, y):
            return torch.nn.functional.dropout(x, 0.6)

        def fn(x, y):
            return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)

        # x = torch.rand(2, 2, device="cuda", requires_grad=True)
        x = torch.ones(2, 2, device="cuda", requires_grad=True)
        y = torch.rand(2, 2, device="cuda", requires_grad=True)
        torch.cuda.manual_seed(123)
        ref = fn(x, y)

        # With checkpointing we should recompute dropout in bwd, replaying the
        # philox seed and offset from fwd, so philox_rand appears once in each.
        fwd_compiler = functools.partial(count_philox_rand, freq=1)
        bwd_compiler = functools.partial(count_philox_rand, freq=1)
        aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
        # We can't check accuracy here because rand_like generates different numbers than dropout.
        res = aot_fn(x, y)
        res.sum().backward()

    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_dropout_decomp(self, dtype, device):
        def fn(x):
            return torch.nn.functional.dropout(x, 0.6) * x

        x = torch.rand(10, device=device, dtype=dtype)

        # Ensure the decomp is happening: dropout should lower to one philox_rand.
        aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
        # We can't check accuracy here because rand_like generates different numbers than dropout.
        aot_fn(x)


only_for = ("cuda",)
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)
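

# RNG functionalization targets CUDA's Philox-based generator, so tracing a
# function that uses RNG ops on CPU is expected to fail loudly; the negative
# test below asserts that aot_function raises a RuntimeError in that case.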
class NegativeTest(TestCase):
    @dtypes(torch.float32)
    @patch.object(torch._functorch.config, "functionalize_rng_ops", True)
    def test_on_cpu(self, dtype, device):
        def fn(x):
            a = torch.rand_like(x) * x
            a = torch.rand_like(x) * a
            return a

        x = torch.rand(10, device=device, dtype=dtype)

        aot_fn = aot_function(fn, nop)
        with self.assertRaises(RuntimeError):
            aot_fn(x)


only_for = ("cpu",)
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)

if __name__ == "__main__":
    run_tests()