import functools
import sys
import unittest
from unittest.mock import patch

import torch
import torch.utils.checkpoint
from functorch.compile import aot_function, min_cut_rematerialization_partition, nop
from torch.testing._internal.common_device_type import (
    dtypes,
    instantiate_device_type_tests,
)
# NOTE(review): the name list of the original common_utils imports was lost in
# extraction; IS_CI / IS_WINDOWS / TestCase / run_tests are the names the
# surviving code demonstrably uses — confirm against upstream.
from torch.testing._internal.common_utils import IS_CI, IS_WINDOWS, run_tests, TestCase
if IS_WINDOWS and IS_CI:
24
"torch.compile not supported on windows"
26
if __name__ == "__main__":
28
raise unittest.SkipTest("torch.compile not supported on windows")
30
def count_philox_rand(gm, args, freq):
31
assert [node.target for node in gm.graph.nodes].count(torch.ops.rngprims.philox_rand.default) == freq
34
class TestFunctionalizationRngOps(TestCase):
35
@dtypes(torch.float32)
36
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
37
def test_rand_like(self, dtype, device):
39
a = torch.rand_like(x) * x
40
a = torch.rand_like(x) * a
43
x = torch.rand(10, device=device, dtype=dtype)
45
for seed in range(10):
46
torch.cuda.manual_seed(seed)
49
torch.cuda.manual_seed(seed)
50
aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
53
self.assertEqual(ref, res)
55
@dtypes(torch.float32)
56
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
57
def test_rand_like_dynamic(self, dtype, device):
59
a = torch.rand_like(x) * x
60
a = torch.rand_like(x) * a
63
for seed in range(1, 10):
65
x = torch.rand(shape, device=device, dtype=dtype)
66
torch.cuda.manual_seed(seed)
69
torch.cuda.manual_seed(seed)
70
opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
73
self.assertEqual(ref, res)
77
@dtypes(torch.float32)
78
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
79
def test_rand_like_dynamic_bwd(self, dtype, device):
81
a = torch.rand_like(x) * x
82
a = torch.rand_like(x) * a
85
for seed in range(1, 10):
87
x = torch.rand(shape, device=device, dtype=dtype, requires_grad=True)
88
torch.cuda.manual_seed(seed)
92
torch.cuda.manual_seed(seed)
93
opt_fn = torch.compile(fn, backend="aot_eager", dynamic=True)
97
self.assertEqual(ref, res)
100
@dtypes(torch.float32)
101
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
102
def test_rand(self, dtype, device):
106
a = torch.rand(*shape, device=device, dtype=dtype) * x
107
a = torch.rand(*shape, device=device, dtype=dtype) * a
110
x = torch.rand(*shape, device=device, dtype=dtype)
112
for seed in range(10):
113
torch.cuda.manual_seed(seed)
116
torch.cuda.manual_seed(seed)
117
aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=2))
120
self.assertEqual(ref, res)
122
@dtypes(torch.float32)
123
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
124
def test_autograd_function(self, dtype, device):
127
class Custom(torch.autograd.Function):
130
ctx.save_for_backward(x)
131
a = torch.rand_like(x) * x
132
a = torch.rand_like(x) * a
136
def backward(ctx, grad_out):
137
x, = ctx.saved_tensors
138
return grad_out * torch.rand_like(grad_out) * torch.cos(x)
140
custom = Custom.apply
142
x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
144
x_clone = x.clone().detach().requires_grad_(True)
146
torch.cuda.manual_seed(123)
150
torch.cuda.manual_seed(123)
151
fwd_compiler = functools.partial(count_philox_rand, freq=2)
152
bwd_compiler = functools.partial(count_philox_rand, freq=1)
153
aot_custom = aot_function(custom, fwd_compiler, bwd_compiler)
154
res = aot_custom(x_clone)
157
self.assertEqual(ref, res)
158
self.assertEqual(x.grad, x_clone.grad)
160
@dtypes(torch.float32)
161
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
162
def test_multiple_subgraphs(self, dtype, device):
167
class CustomOp1(torch.autograd.Function):
170
ctx.save_for_backward(x)
171
a = torch.rand_like(x) * x
172
a = torch.rand_like(x) * a
176
def backward(ctx, grad_out):
177
x, = ctx.saved_tensors
178
return grad_out * torch.rand_like(grad_out) * torch.cos(x)
180
class CustomOp2(torch.autograd.Function):
183
ctx.save_for_backward(x)
184
a = torch.rand_like(x) * x
188
def backward(ctx, grad_out):
189
x, = ctx.saved_tensors
190
return grad_out * torch.rand_like(grad_out) * torch.rand_like(x)
193
custom_op1 = CustomOp1.apply
194
custom_op2 = CustomOp2.apply
201
fwd_compiler = functools.partial(count_philox_rand, freq=2)
202
bwd_compiler = functools.partial(count_philox_rand, freq=1)
203
aot_custom_op1 = aot_function(custom_op1, fwd_compiler, bwd_compiler)
204
fwd_compiler = functools.partial(count_philox_rand, freq=1)
205
bwd_compiler = functools.partial(count_philox_rand, freq=2)
206
aot_custom_op2 = aot_function(custom_op2, fwd_compiler, bwd_compiler)
209
a = aot_custom_op1(x)
211
return aot_custom_op2(b)
214
for seed in range(10):
215
torch.cuda.manual_seed(seed)
216
x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
217
x_clone = x.clone().detach().requires_grad_(True)
219
torch.cuda.manual_seed(seed)
223
torch.cuda.manual_seed(seed)
224
res = aot_fn(x_clone)
227
self.assertEqual(ref, res)
228
self.assertEqual(x.grad, x_clone.grad)
230
@dtypes(torch.float32)
231
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
232
def test_set_get_rng_state(self, dtype, device):
234
a = torch.rand_like(x) * x
235
state = torch.cuda.get_rng_state()
236
a = torch.rand_like(x) * a
237
torch.cuda.set_rng_state(state)
238
a = torch.rand_like(x) * a
241
x = torch.rand(10, device=device, dtype=dtype)
243
for seed in range(10):
244
torch.cuda.manual_seed(seed)
247
torch.cuda.manual_seed(seed)
248
fwd_compiler = functools.partial(count_philox_rand, freq=3)
249
aot_fn = aot_function(fn, fwd_compiler)
252
self.assertEqual(ref, res)
254
@dtypes(torch.float32)
255
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
256
def test_min_cut_partitioner(self, dtype, device):
261
a = torch.rand_like(x) * x
262
a = torch.rand_like(x) * a
269
x = torch.rand(*shape, device=device, dtype=dtype, requires_grad=True)
271
x_clone = x.clone().detach().requires_grad_(True)
273
torch.cuda.manual_seed(123)
277
torch.cuda.manual_seed(123)
278
fwd_compiler = functools.partial(count_philox_rand, freq=2)
279
bwd_compiler = functools.partial(count_philox_rand, freq=0)
280
aot_custom = aot_function(fn, fwd_compiler, bwd_compiler, partition_fn=min_cut_rematerialization_partition)
282
res = aot_custom(x_clone)
285
self.assertEqual(ref, res)
286
self.assertEqual(x.grad, x_clone.grad)
289
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
290
@dtypes(torch.float32)
291
def test_checkpoint(self, dtype, device):
293
return torch.nn.functional.dropout(x, 0.6)
296
return torch.utils.checkpoint.checkpoint(g, x, y, use_reentrant=False)
299
x = torch.ones(2, 2, device="cuda", requires_grad=True)
300
y = torch.rand(2, 2, device="cuda", requires_grad=True)
301
torch.cuda.manual_seed(123)
305
fwd_compiler = functools.partial(count_philox_rand, freq=1)
306
bwd_compiler = functools.partial(count_philox_rand, freq=1)
307
aot_fn = aot_function(fn, fwd_compiler, bwd_compiler)
312
@dtypes(torch.float32)
313
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
314
def test_dropout_decomp(self, dtype, device):
316
return torch.nn.functional.dropout(x, 0.6) * x
318
x = torch.rand(10, device=device, dtype=dtype)
321
aot_fn = aot_function(fn, functools.partial(count_philox_rand, freq=1))
327
instantiate_device_type_tests(TestFunctionalizationRngOps, globals(), only_for=only_for)
330
class NegativeTest(TestCase):
331
@dtypes(torch.float32)
332
@patch.object(torch._functorch.config, "functionalize_rng_ops", True)
333
def test_on_cpu(self, dtype, device):
335
a = torch.rand_like(x) * x
336
a = torch.rand_like(x) * a
339
x = torch.rand(10, device=device, dtype=dtype)
341
aot_fn = aot_function(fn, nop)
342
with self.assertRaises(RuntimeError):
347
instantiate_device_type_tests(NegativeTest, globals(), only_for=only_for)
349
if __name__ == "__main__":